Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)
Merge branch 'upstream' into concedo_experimental
# Conflicts:
#	.devops/main-intel.Dockerfile
#	.devops/main-vulkan.Dockerfile
#	.devops/server-intel.Dockerfile
#	.devops/server-vulkan.Dockerfile
#	.github/workflows/bench.yml
#	.github/workflows/build.yml
#	.github/workflows/python-lint.yml
#	.github/workflows/server.yml
#	.gitignore
#	Makefile
#	README-sycl.md
#	README.md
#	ci/run.sh
#	flake.lock
#	llama.cpp
#	models/ggml-vocab-falcon.gguf
#	models/ggml-vocab-llama-spm.gguf
#	models/ggml-vocab-mpt.gguf
#	models/ggml-vocab-stablelm.gguf
#	models/ggml-vocab-starcoder.gguf
#	requirements.txt
#	scripts/check-requirements.sh
#	tests/CMakeLists.txt
#	tests/test-backend-ops.cpp
#	tests/test-grammar-integration.cpp
#	tests/test-tokenizer-0-bpe.py
#	tests/test-tokenizer-0-spm.py
#	tests/test-tokenizer-1-spm.cpp
commit 17a24d753c
52 changed files with 4978 additions and 1249 deletions
@@ -68,7 +68,6 @@
 #include <sys/syslimits.h>
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
-#define LLAMA_CURL_MAX_HEADER_LENGTH 256
 #endif // LLAMA_USE_CURL
 
 using json = nlohmann::ordered_json;
@@ -235,6 +234,52 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     return result;
 }
 
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr || sep - data >= 128) {
+        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
+        return false;
+    }
+    llama_model_kv_override kvo;
+    std::strncpy(kvo.key, data, sep - data);
+    kvo.key[sep - data] = 0;
+    sep++;
+    if (strncmp(sep, "int:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+        kvo.val_i64 = std::atol(sep);
+    } else if (strncmp(sep, "float:", 6) == 0) {
+        sep += 6;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+        kvo.val_f64 = std::atof(sep);
+    } else if (strncmp(sep, "bool:", 5) == 0) {
+        sep += 5;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+        if (std::strcmp(sep, "true") == 0) {
+            kvo.val_bool = true;
+        } else if (std::strcmp(sep, "false") == 0) {
+            kvo.val_bool = false;
+        } else {
+            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
+            return false;
+        }
+    } else if (strncmp(sep, "str:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+        if (strlen(sep) > 127) {
+            fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+            return false;
+        }
+        strncpy(kvo.val_str, sep, 127);
+        kvo.val_str[127] = '\0';
+    } else {
+        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
+        return false;
+    }
+    overrides.emplace_back(std::move(kvo));
+    return true;
+}
+
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     llama_sampling_params & sparams = params.sparams;
 
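The helper above accepts strings of the form KEY=TYPE:VALUE, one per repeated --override-kv flag. A minimal standalone sketch of a caller, assuming common.h from this patch is on the include path; the override strings here are illustrative, not taken from the commit:

    // Sketch only: exercises parse_kv_override() with sample override strings.
    #include <cstdio>
    #include <vector>
    #include "common.h" // declares parse_kv_override and pulls in llama_model_kv_override

    int main() {
        std::vector<llama_model_kv_override> overrides;
        const char * args[] = {
            "tokenizer.ggml.add_bos_token=bool:false", // bool override
            "some.metadata.key=str:example",           // hypothetical key; str overrides are new in this patch
        };
        for (const char * arg : args) {
            if (!parse_kv_override(arg, overrides)) {
                fprintf(stderr, "rejected: %s\n", arg);
            }
        }
        fprintf(stderr, "parsed %zu overrides\n", overrides.size());
        return 0;
    }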
@@ -848,7 +893,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
-        params.image = argv[i];
+        params.image.emplace_back(argv[i]);
         return true;
     }
     if (arg == "-i" || arg == "--interactive") {
@@ -903,6 +948,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = true;
         return true;
     }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
     if (arg == "--color") {
         params.use_color = true;
         return true;
@@ -1090,6 +1139,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_print = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--check-tensors") {
+        params.check_tensors = true;
+        return true;
+    }
     if (arg == "--ppl-output-type") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1241,47 +1294,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
-        char* sep = strchr(argv[i], '=');
-        if (sep == nullptr || sep - argv[i] >= 128) {
-            fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        struct llama_model_kv_override kvo;
-        std::strncpy(kvo.key, argv[i], sep - argv[i]);
-        kvo.key[sep - argv[i]] = 0;
-        sep++;
-        if (strncmp(sep, "int:", 4) == 0) {
-            sep += 4;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-            kvo.int_value = std::atol(sep);
-        }
-        else if (strncmp(sep, "float:", 6) == 0) {
-            sep += 6;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-            kvo.float_value = std::atof(sep);
-        }
-        else if (strncmp(sep, "bool:", 5) == 0) {
-            sep += 5;
-            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
-            if (std::strcmp(sep, "true") == 0) {
-                kvo.bool_value = true;
-            }
-            else if (std::strcmp(sep, "false") == 0) {
-                kvo.bool_value = false;
-            }
-            else {
-                fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
-                invalid_param = true;
-                return true;
-            }
-        }
-        else {
+        if (!parse_kv_override(argv[i], params.kv_overrides)) {
             fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
             invalid_param = true;
             return true;
         }
-        params.kv_overrides.push_back(kvo);
         return true;
     }
 #ifndef LOG_DISABLE_LOGS
@@ -1311,6 +1328,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     return false;
 }
 
+void gpt_params_handle_model_default(gpt_params & params) {
+    if (!params.hf_repo.empty()) {
+        // short-hand to avoid specifying --hf-file -> default it to --model
+        if (params.hf_file.empty()) {
+            if (params.model.empty()) {
+                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
+            }
+            params.hf_file = params.model;
+        } else if (params.model.empty()) {
+            params.model = "models/" + string_split(params.hf_file, '/').back();
+        }
+    } else if (!params.model_url.empty()) {
+        if (params.model.empty()) {
+            auto f = string_split(params.model_url, '#').front();
+            f = string_split(f, '?').front();
+            f = string_split(f, '/').back();
+            params.model = "models/" + f;
+        }
+    } else if (params.model.empty()) {
+        params.model = DEFAULT_MODEL_PATH;
+    }
+}
+
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
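gpt_params_handle_model_default() resolves --model from --hf-repo/--hf-file or --model-url before falling back to DEFAULT_MODEL_PATH. For the URL case it drops the fragment and query string and keeps the last path segment. A standalone sketch of that derivation, reimplemented with std::string only so it compiles in isolation (the URL below is made up):

    #include <iostream>
    #include <string>

    // Mirrors the string_split('#') -> ('?') -> ('/') chain used above.
    static std::string derive_model_path(std::string url) {
        url = url.substr(0, url.find('#'));  // drop fragment, if any
        url = url.substr(0, url.find('?'));  // drop query string, if any
        auto pos = url.rfind('/');           // keep last path segment
        return "models/" + (pos == std::string::npos ? url : url.substr(pos + 1));
    }

    int main() {
        std::cout << derive_model_path("https://example.com/repo/model-Q4.gguf?download=true") << "\n";
        // prints: models/model-Q4.gguf
        return 0;
    }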
@@ -1339,10 +1379,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
-    // short-hand to avoid specifying --hf-file -> default it to --model
-    if (!params.hf_repo.empty() && params.hf_file.empty()) {
-        params.hf_file = params.model;
-    }
+    gpt_params_handle_model_default(params);
 
     if (params.escape) {
         process_escapes(params.prompt);
@@ -1481,8 +1518,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ns N, --sequences N      number of sequences to decode (default: %d)\n", params.n_sequences);
     printf("  -ps N, --p-split N        speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching      enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf("  -fa, --flash-attn         enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  --mmproj MMPROJ_FILE      path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
-    printf("  --image IMAGE_FILE        path to an image file. use with multimodal models\n");
+    printf("  --image IMAGE_FILE        path to an image file. use with multimodal models. Specify multiple times for batching\n");
     if (llama_supports_mlock()) {
         printf("  --mlock                   force system to keep model in RAM rather than swapping or compressing\n");
     }
@@ -1535,7 +1573,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --control-vector-layer-range START END\n");
     printf("                            layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
-    printf("                            model path (default: %s)\n", params.model.c_str());
+    printf("                            model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH);
     printf("  -md FNAME, --model-draft FNAME\n");
     printf("                            draft model for speculative decoding (default: unused)\n");
     printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
@@ -1552,9 +1590,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                            path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                            advanced option to override model metadata by key. may be specified multiple times.\n");
-    printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("                            types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -ptc N, --print-token-count N\n");
     printf("                            print token count every N tokens (default: %d)\n", params.n_print);
+    printf("  --check-tensors           check model tensor data for invalid values\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
@@ -1679,6 +1718,18 @@ std::vector<std::string> string_split(std::string input, char separator) {
     return parts;
 }
 
+std::string string_strip(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && std::isspace(str[start])) {
+        start++;
+    }
+    while (end > start && std::isspace(str[end - 1])) {
+        end--;
+    }
+    return str.substr(start, end - start);
+}
+
 std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
        {"top_k", llama_sampler_type::TOP_K},
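string_strip() trims leading and trailing whitespace by index, without allocating intermediates. A self-contained check of the behavior added above (the sample input is arbitrary):

    #include <cctype>
    #include <iostream>
    #include <string>

    // Same implementation as the string_strip() added in this hunk.
    std::string string_strip(const std::string & str) {
        size_t start = 0;
        size_t end = str.size();
        while (start < end && std::isspace(str[start])) { start++; }
        while (end > start && std::isspace(str[end - 1])) { end--; }
        return str.substr(start, end - start);
    }

    int main() {
        std::cout << "[" << string_strip("  top_k; top_p \t\n") << "]\n"; // prints [top_k; top_p]
        return 0;
    }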
@@ -1775,6 +1826,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.tensor_split = params.tensor_split;
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
+    mparams.check_tensors = params.check_tensors;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
@@ -1839,6 +1891,7 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
+    cparams.flash_attn = params.flash_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -1869,59 +1922,75 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
-static bool llama_download_file(CURL * curl, const char * url, const char * path) {
+static bool starts_with(const std::string & str, const std::string & prefix) {
+    // While we wait for C++20's std::string::starts_with...
+    return str.rfind(prefix, 0) == 0;
+}
+
+static bool llama_download_file(const std::string & url, const std::string & path) {
+
+    // Initialize libcurl
+    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+    if (!curl) {
+        fprintf(stderr, "%s: error initializing libcurl\n", __func__);
+        return false;
+    }
 
     bool force_download = false;
 
     // Set the URL, allow to follow http redirection
-    curl_easy_setopt(curl, CURLOPT_URL, url);
-    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
 
 #if defined(_WIN32)
     // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
     // operating system. Currently implemented under MS-Windows.
-    curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
 #endif
 
     // Check if the file already exists locally
     struct stat model_file_info;
-    auto file_exists = (stat(path, &model_file_info) == 0);
+    auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
 
-    // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files
-    char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
-    char etag_path[PATH_MAX] = {0};
-    snprintf(etag_path, sizeof(etag_path), "%s.etag", path);
-
-    char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
-    char last_modified_path[PATH_MAX] = {0};
-    snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path);
+    // If the file exists, check its JSON metadata companion file.
+    std::string metadata_path = path + ".json";
+    nlohmann::json metadata;
+    std::string etag;
+    std::string last_modified;
 
     if (file_exists) {
-        auto * f_etag = fopen(etag_path, "r");
-        if (f_etag) {
-            if (!fgets(etag, sizeof(etag), f_etag)) {
-                fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path);
-            } else {
-                fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag);
-            }
-            fclose(f_etag);
-        }
-
-        auto * f_last_modified = fopen(last_modified_path, "r");
-        if (f_last_modified) {
-            if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) {
-                fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path);
-            } else {
-                fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path,
-                        last_modified);
-            }
-            fclose(f_last_modified);
-        }
+        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
+        std::ifstream metadata_in(metadata_path);
+        if (metadata_in.good()) {
+            try {
+                metadata_in >> metadata;
+                fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
+                if (metadata.contains("url") && metadata["url"].is_string()) {
+                    auto previous_url = metadata["url"].get<std::string>();
+                    if (previous_url != url) {
+                        fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
+                        return false;
+                    }
+                }
+                if (metadata.contains("etag") && metadata["etag"].is_string()) {
+                    etag = metadata["etag"];
+                }
+                if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) {
+                    last_modified = metadata["lastModified"];
+                }
+            } catch (const nlohmann::json::exception & e) {
+                fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
+                return false;
+            }
+        }
+    } else {
+        fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str());
     }
 
     // Send a HEAD request to retrieve the etag and last-modified headers
     struct llama_load_model_from_url_headers {
-        char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
-        char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0};
+        std::string etag;
+        std::string last_modified;
     };
     llama_load_model_from_url_headers headers;
     {
@@ -1929,38 +1998,37 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
         auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
             llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata;
 
-            // Convert header field name to lowercase
-            for (size_t i = 0; i < n_items && buffer[i] != ':'; ++i) {
-                buffer[i] = tolower(buffer[i]);
-            }
+            static std::regex header_regex("([^:]+): (.*)\r\n");
+            static std::regex etag_regex("ETag", std::regex_constants::icase);
+            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
 
-            const char * etag_prefix = "etag: ";
-            if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) {
-                strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF
-            }
-
-            const char * last_modified_prefix = "last-modified: ";
-            if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) {
-                strncpy(headers->last_modified, buffer + strlen(last_modified_prefix),
-                        n_items - strlen(last_modified_prefix) - 2); // Remove CRLF
+            std::string header(buffer, n_items);
+            std::smatch match;
+            if (std::regex_match(header, match, header_regex)) {
+                const std::string & key = match[1];
+                const std::string & value = match[2];
+                if (std::regex_match(key, match, etag_regex)) {
+                    headers->etag = value;
+                } else if (std::regex_match(key, match, last_modified_regex)) {
+                    headers->last_modified = value;
+                }
             }
             return n_items;
         };
 
-        curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
-        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress
-        curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
-        curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers);
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
 
-        CURLcode res = curl_easy_perform(curl);
+        CURLcode res = curl_easy_perform(curl.get());
         if (res != CURLE_OK) {
-            curl_easy_cleanup(curl);
             fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
             return false;
         }
 
         long http_code = 0;
-        curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
         if (http_code != 200) {
             // HEAD not supported, we don't know if the file has changed
             // force trigger downloading
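The header callback now matches header names case-insensitively with std::regex instead of lowercasing the buffer in place. A standalone sketch of the same matching (the sample header line is illustrative):

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // Same patterns as header_callback above: split "Name: value\r\n",
        // then compare the field name case-insensitively.
        static std::regex header_regex("([^:]+): (.*)\r\n");
        static std::regex etag_regex("ETag", std::regex_constants::icase);

        std::string header = "etag: \"abc123\"\r\n"; // servers may use any casing
        std::smatch match;
        if (std::regex_match(header, match, header_regex)) {
            const std::string key = match[1];
            const std::string value = match[2];
            if (std::regex_match(key, match, etag_regex)) {
                std::cout << "etag value: " << value << "\n"; // etag value: "abc123"
            }
        }
        return 0;
    }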
@@ -1969,28 +2037,30 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
         }
     }
 
-    // If the ETag or the Last-Modified headers are different: trigger a new download
-    bool should_download = !file_exists
-        || force_download
-        || (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0)
-        || (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0);
+    bool should_download = !file_exists || force_download;
+    if (!should_download) {
+        if (!etag.empty() && etag != headers.etag) {
+            fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
+            should_download = true;
+        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
+            fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
+            should_download = true;
+        }
+    }
     if (should_download) {
-        char path_temporary[PATH_MAX] = {0};
-        snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path);
+        std::string path_temporary = path + ".downloadInProgress";
         if (file_exists) {
-            fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path);
-            if (remove(path) != 0) {
-                curl_easy_cleanup(curl);
-                fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path);
+            fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
+            if (remove(path.c_str()) != 0) {
+                fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str());
                 return false;
             }
         }
 
         // Set the output file
-        auto * outfile = fopen(path_temporary, "wb");
+        std::unique_ptr<FILE, decltype(&fclose)> outfile(fopen(path_temporary.c_str(), "wb"), fclose);
         if (!outfile) {
-            curl_easy_cleanup(curl);
-            fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path);
+            fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str());
             return false;
         }
 
@@ -1998,12 +2068,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
         auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
             return fwrite(data, size, nmemb, (FILE *)fd);
         };
-        curl_easy_setopt(curl, CURLOPT_NOBODY, 0L);
-        curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
-        curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile);
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
 
         // display download progress
-        curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
 
         // helper function to hide password in URL
         auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
@@ -2022,51 +2092,34 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
 
         // start the download
         fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
-            llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified);
-        auto res = curl_easy_perform(curl);
+            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
+        auto res = curl_easy_perform(curl.get());
         if (res != CURLE_OK) {
-            fclose(outfile);
-            curl_easy_cleanup(curl);
             fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
             return false;
         }
 
         long http_code = 0;
-        curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code);
+        curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
         if (http_code < 200 || http_code >= 400) {
-            fclose(outfile);
-            curl_easy_cleanup(curl);
             fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code);
             return false;
         }
 
-        // Clean up
-        fclose(outfile);
+        // Causes file to be closed explicitly here before we rename it.
+        outfile.reset();
 
-        // Write the new ETag to the .etag file
-        if (strlen(headers.etag) > 0) {
-            auto * etag_file = fopen(etag_path, "w");
-            if (etag_file) {
-                fputs(headers.etag, etag_file);
-                fclose(etag_file);
-                fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
-            }
-        }
-
-        // Write the new lastModified to the .etag file
-        if (strlen(headers.last_modified) > 0) {
-            auto * last_modified_file = fopen(last_modified_path, "w");
-            if (last_modified_file) {
-                fputs(headers.last_modified, last_modified_file);
-                fclose(last_modified_file);
-                fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
-                        headers.last_modified);
-            }
-        }
-
-        if (rename(path_temporary, path) != 0) {
-            curl_easy_cleanup(curl);
-            fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path);
+        // Write the updated JSON metadata file.
+        metadata.update({
+            {"url", url},
+            {"etag", headers.etag},
+            {"lastModified", headers.last_modified}
+        });
+        std::ofstream(metadata_path) << metadata.dump(4);
+        fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+
+        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
+            fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
             return false;
         }
     }
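The .etag/.lastModified companion files are replaced by a single <model>.json metadata file next to the downloaded model. A sketch of writing and re-reading such a file, assuming the single-header nlohmann/json library is available; the file name and field values below are illustrative:

    #include <fstream>
    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
        // Mirrors the metadata written above for e.g. models/foo.gguf -> models/foo.gguf.json
        nlohmann::json metadata;
        metadata.update({
            {"url",          "https://example.com/foo.gguf"},
            {"etag",         "\"0123456789abcdef\""},
            {"lastModified", "Tue, 30 Apr 2024 00:00:00 GMT"}
        });
        std::ofstream out("foo.gguf.json");
        out << metadata.dump(4);
        out.close();

        // On the next run the file is parsed back and compared against the
        // HEAD response before deciding whether to re-download.
        nlohmann::json readback;
        std::ifstream in("foo.gguf.json");
        in >> readback;
        std::cout << readback["etag"].get<std::string>() << "\n";
        return 0;
    }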
@@ -2084,15 +2137,7 @@ struct llama_model * llama_load_model_from_url(
         return NULL;
     }
 
-    // Initialize libcurl
-    auto * curl = curl_easy_init();
-
-    if (!curl) {
-        fprintf(stderr, "%s: error initializing libcurl\n", __func__);
-        return NULL;
-    }
-
-    if (!llama_download_file(curl, model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model)) {
         return NULL;
     }
 
@@ -2106,7 +2151,6 @@ struct llama_model * llama_load_model_from_url(
     auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params);
     if (!ctx_gguf) {
         fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model);
-        curl_easy_cleanup(curl);
         return NULL;
     }
 
@@ -2118,8 +2162,6 @@ struct llama_model * llama_load_model_from_url(
         gguf_free(ctx_gguf);
     }
 
-    curl_easy_cleanup(curl);
-
     if (n_split > 1) {
         char split_prefix[PATH_MAX] = {0};
         char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
@@ -2150,11 +2192,7 @@ struct llama_model * llama_load_model_from_url(
                 char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
                 llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
 
-                auto * curl = curl_easy_init();
-                bool res = llama_download_file(curl, split_url, split_path);
-                curl_easy_cleanup(curl);
-
-                return res;
+                return llama_download_file(split_url, split_path);
             }, idx));
         }
 
@@ -2641,7 +2679,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
-    fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
     fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
@@ -2676,6 +2714,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
@@ -31,6 +31,8 @@
     fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
 } while(0)
 
+#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
+
 // build info
 
 struct llama_control_vector_load_info;
@@ -108,7 +110,7 @@ struct gpt_params {
     // // sampling parameters
     struct llama_sampling_params sparams;
 
-    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
     std::string model_alias = "unknown"; // model alias
     std::string model_url = ""; // model url to download
@@ -164,6 +166,7 @@ struct gpt_params {
     bool multiline_input = false; // reverse the usage of `\`
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
+    bool flash_attn = false; // flash attention
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos = false; // ignore generated EOS tokens
@@ -177,15 +180,20 @@ struct gpt_params {
     bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
     bool no_kv_offload = false; // disable KV offloading
     bool warmup = true; // warmup run
+    bool check_tensors = false; // validate tensor data
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
 
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
-    std::string image = ""; // path to an image file
+    std::vector<std::string> image; // path to image file(s)
 };
 
+void gpt_params_handle_model_default(gpt_params & params);
+
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -209,6 +217,7 @@ bool validate_file_name(const std::string & filename);
 std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
 std::vector<std::string> string_split(std::string input, char separator);
+std::string string_strip(const std::string & str);
 std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
 
 //
@@ -234,7 +234,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
 // INTERNAL, DO NOT USE
 // USE LOG() INSTEAD
 //
-#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER)
+#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER)
 #define LOG_IMPL(str, ...) \
     do { \
         if (LOG_TARGET != nullptr) \
@@ -257,7 +257,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std::
 // INTERNAL, DO NOT USE
 // USE LOG_TEE() INSTEAD
 //
-#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER)
+#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER)
 #define LOG_TEE_IMPL(str, ...) \
     do { \
         if (LOG_TARGET != nullptr) \
@@ -68,7 +68,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
+        seed = std::random_device{}();
     }
     ctx->rng.seed(seed);
 }
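Seeding from std::random_device instead of time(NULL) keeps two sampling contexts created within the same second from producing identical streams. A minimal standalone illustration:

    #include <iostream>
    #include <random>

    int main() {
        // Two RNGs seeded with time(NULL) in the same second would match;
        // std::random_device draws a fresh nondeterministic seed per call.
        std::mt19937 a(std::random_device{}());
        std::mt19937 b(std::random_device{}());
        std::cout << a() << " vs " << b() << "\n"; // almost surely different
        return 0;
    }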
convert-hf-to-gguf-update.py (new file, 279 lines)
@@ -0,0 +1,279 @@
+# This script downloads the tokenizer models of the specified models from Huggingface and
+# generates the get_vocab_base_pre() function for convert-hf-to-gguf.py
+#
+# This is necessary in order to analyze the type of pre-tokenizer used by the model and
+# provide the necessary information to llama.cpp via the GGUF header in order to implement
+# the same pre-tokenizer.
+#
+# ref: https://github.com/ggerganov/llama.cpp/pull/6920
+#
+# Instructions:
+#
+# - Add a new model to the "models" list
+# - Run the script with your huggingface token:
+#
+#   python3 convert-hf-to-gguf-update.py <huggingface_token>
+#
+# - Copy-paste the generated get_vocab_base_pre() function into convert-hf-to-gguf.py
+# - Update llama.cpp with the new pre-tokenizer if necessary
+#
+# TODO: generate tokenizer tests for llama.cpp
+# TODO: automate the update of convert-hf-to-gguf.py
+#
+
+import os
+import requests
+import sys
+import json
+
+from hashlib import sha256
+from enum import IntEnum, auto
+
+class TOKENIZER_TYPE(IntEnum):
+    SPM = auto()
+    BPE = auto()
+    WPM = auto()
+
+# TODO: this string has to exercise as much pre-tokenizer functionality as possible
+#       will be updated with time - contributions welcome
+chktxt = '\n \n\n \n\n\n \t \t\t \t\n  \n   \n    \n     \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
+
+if len(sys.argv) == 2:
+    token = sys.argv[1]
+else:
+    print("Usage: python convert-hf-to-gguf-update.py <huggingface_token>")
+    sys.exit(1)
+
+# TODO: add models here, base models preferred
+models = [
+    { "name": "llama-spm",      "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", },
+    { "name": "llama-bpe",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", },
+    { "name": "phi-3",          "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", },
+    { "name": "deepseek-llm",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", },
+    { "name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", },
+    { "name": "falcon",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", },
+    { "name": "bert-bge",       "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", },
+    { "name": "mpt",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", },
+    { "name": "starcoder",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", },
+    { "name": "gpt-2",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", },
+]
+
+# make directory "models/tokenizers" if it doesn't exist
+if not os.path.exists("models/tokenizers"):
+    os.makedirs("models/tokenizers")
+
+def download_file_with_auth(url, token, save_path):
+    headers = {"Authorization": f"Bearer {token}"}
+    response = requests.get(url, headers=headers)
+    if response.status_code == 200:
+        with open(save_path, 'wb') as f:
+            f.write(response.content)
+        print(f"File {save_path} downloaded successfully")
+    else:
+        print(f"Failed to download file. Status code: {response.status_code}")
+
+# download the tokenizer models
+for model in models:
+    name = model["name"]
+    repo = model["repo"]
+    tokt = model["tokt"]
+
+    if not os.path.exists(f"models/tokenizers/{name}"):
+        os.makedirs(f"models/tokenizers/{name}")
+    else:
+        print(f"Directory models/tokenizers/{name} already exists - skipping")
+        continue
+
+    print(f"Downloading {name} to models/tokenizers/{name}")
+
+    url = f"{repo}/raw/main/config.json"
+    save_path = f"models/tokenizers/{name}/config.json"
+    download_file_with_auth(url, token, save_path)
+
+    url = f"{repo}/raw/main/tokenizer.json"
+    save_path = f"models/tokenizers/{name}/tokenizer.json"
+    download_file_with_auth(url, token, save_path)
+
+    if tokt == TOKENIZER_TYPE.SPM:
+        url = f"{repo}/resolve/main/tokenizer.model"
+        save_path = f"models/tokenizers/{name}/tokenizer.model"
+        download_file_with_auth(url, token, save_path)
+
+    url = f"{repo}/raw/main/tokenizer_config.json"
+    save_path = f"models/tokenizers/{name}/tokenizer_config.json"
+    download_file_with_auth(url, token, save_path)
+
+# generate the source code for the convert-hf-to-gguf.py:get_vocab_base_pre() function:
+# TODO: auto-update convert-hf-to-gguf.py with the generated function
+
+src_ifs = ""
+for model in models:
+    name = model["name"]
+    tokt = model["tokt"]
+
+    if tokt == TOKENIZER_TYPE.SPM:
+        continue
+
+    # create the tokenizer
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+
+    chktok = tokenizer.encode(chktxt)
+    chkhsh = sha256(str(chktok).encode()).hexdigest()
+
+    print(f"model: {name}")
+    print(f"tokt: {tokt}")
+    print(f"repo: {model['repo']}")
+    print(f"chktok: {chktok}")
+    print(f"chkhsh: {chkhsh}")
+
+    # print the "pre_tokenizer" content from the tokenizer.json
+    with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f:
+        cfg = json.load(f)
+        pre_tokenizer = cfg["pre_tokenizer"]
+        print("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4))
+
+    print(f"\n")
+
+    src_ifs += f"        if chkhsh == \"{chkhsh}\":\n"
+    src_ifs += f"            # ref: {model['repo']}\n"
+    src_ifs += f"            res = \"{name}\"\n"
+
+src_func = ""
+src_func += "    def get_vocab_base_pre(self, tokenizer) -> str:\n"
+src_func += "        # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that\n"
+src_func += "        # is specific for the BPE pre-tokenizer used by the model\n"
+src_func += "        # we will use this unique identifier to write a \"tokenizer.ggml.pre\" entry in the GGUF file which we can\n"
+src_func += "        # use in llama.cpp to implement the same pre-tokenizer\n"
+src_func += "\n"
+src_func += f"        chktxt = {repr(chktxt)}\n"
+src_func += "\n"
+src_func += "        chktok = tokenizer.encode(chktxt)\n"
+src_func += "        chkhsh = sha256(str(chktok).encode()).hexdigest()\n"
+src_func += "\n"
+src_func += "        print(f\"chktok: {chktok}\")\n"
+src_func += "        print(f\"chkhsh: {chkhsh}\")\n"
+src_func += "\n"
+src_func += "        res = None\n"
+src_func += "\n"
+src_func += "        # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script\n"
+src_func += "        #       or pull the latest version of the model from Huggingface\n"
+src_func += "        #       don't edit the hashes manually!\n"
+src_func += f"{src_ifs}\n"
+src_func += "        if res is None:\n"
+src_func += "            print(\"\\n\")\n"
+src_func += "            print(\"**************************************************************************************\")\n"
+src_func += "            print(\"** WARNING: The BPE pre-tokenizer was not recognized!\")\n"
+src_func += "            print(\"** There are 2 possible reasons for this:\")\n"
+src_func += "            print(\"** - the model has not been added to convert-hf-to-gguf-update.py yet\")\n"
+src_func += "            print(\"** - the pre-tokenization config has changed upstream\")\n"
+src_func += "            print(\"** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.\")\n"
+src_func += "            print(\"** ref: https://github.com/ggerganov/llama.cpp/pull/6920\")\n"
+src_func += "            print(\"**\")\n"
+src_func += "            print(f\"** chkhsh: {chkhsh}\")\n"
+src_func += "            print(\"**************************************************************************************\")\n"
+src_func += "            print(\"\\n\")\n"
+src_func += "            raise NotImplementedError(\"BPE pre-tokenizer was not recognized - update get_vocab_base_pre()\")\n"
+src_func += "\n"
+src_func += "        print(f\"tokenizer.ggml.pre: {res}\")\n"
+src_func += "        print(f\"chkhsh: {chkhsh}\")\n"
+src_func += "\n"
+src_func += "        return res\n"
+
+print(src_func)
+
+print("\n")
+print("!!! Copy-paste the function above into convert-hf-to-gguf.py !!!")
+print("\n")
+
+# generate tests for each tokenizer model
+
+tests = [
+    "",
+    " ",
+    "  ",
+    "   ",
+    "\t",
+    "\n",
+    "\n\n",
+    "\n\n\n",
+    "\t\n",
+    "Hello world",
+    " Hello world",
+    "Hello World",
+    " Hello World",
+    " Hello World!",
+    "Hello, world!",
+    " Hello, world!",
+    " this is 🦙.cpp",
+    "w048 7tuijk dsdfhu",
+    "нещо на Български",
+    "កាន់តែពិសេសអាចខលចេញ",
+    "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
+    "Hello",
+    " Hello",
+    "  Hello",
+    "   Hello",
+    "    Hello",
+    "    Hello\n    Hello",
+    " (",
+    "\n =",
+    "' era",
+    "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+    "3",
+    "33",
+    "333",
+    "3333",
+    "33333",
+    "333333",
+    "3333333",
+    "33333333",
+    "333333333",
+    chktxt,
+]
+
+# write the tests to ./models/ggml-vocab-{name}.gguf.inp
+# the format is:
+#
+# test0
+# __ggml_vocab_test__
+# test1
+# __ggml_vocab_test__
+# ...
+#
+
+# with each model, encode all tests and write the results in ./models/ggml-vocab-{name}.gguf.out
+# for each test, write the resulting tokens on a separate line
+
+for model in models:
+    name = model["name"]
+    tokt = model["tokt"]
+
+    # create the tokenizer
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+
+    with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f:
+        for text in tests:
+            f.write(f"{text}")
+            f.write("\n__ggml_vocab_test__\n")
+
+    with open(f"models/ggml-vocab-{name}.gguf.out", "w") as f:
+        for text in tests:
+            res = tokenizer.encode(text, add_special_tokens=False)
+            for r in res:
+                f.write(f" {r}")
+            f.write("\n")
+
+    print(f"Tests for {name} written in ./models/ggml-vocab-{name}.gguf.*")
+
+# generate commands for creating vocab files
+
+print("\nRun the following commands to generate the vocab files for testing:\n")
+
+for model in models:
+    name = model["name"]
+
+    print(f"python3 convert-hf-to-gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only")
+
+print("\n")
@@ -11,6 +11,7 @@ import sys
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from pathlib import Path
+from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
@@ -229,7 +230,7 @@ class Model(ABC):
         return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
 
     # used for GPT-2 BPE and WordPiece vocabs
-    def get_basic_vocab(self) -> tuple[list[str], list[int]]:
+    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         tokens: list[str] = []
         toktypes: list[int] = []
 
@ -238,6 +239,8 @@ class Model(ABC):
|
||||||
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
|
||||||
assert max(tokenizer.vocab.values()) < vocab_size
|
assert max(tokenizer.vocab.values()) < vocab_size
|
||||||
|
|
||||||
|
tokpre = self.get_vocab_base_pre(tokenizer)
|
||||||
|
|
||||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||||
added_vocab = tokenizer.get_added_vocab()
|
added_vocab = tokenizer.get_added_vocab()
|
||||||
|
|
||||||
|
@ -255,11 +258,79 @@ class Model(ABC):
|
||||||
tokens.append(reverse_vocab[i])
|
tokens.append(reverse_vocab[i])
|
||||||
toktypes.append(gguf.TokenType.NORMAL)
|
toktypes.append(gguf.TokenType.NORMAL)
|
||||||
|
|
||||||
return tokens, toktypes
|
return tokens, toktypes, tokpre
|
||||||
|
|
||||||
|
# NOTE: this function is generated by convert-hf-to-gguf-update.py
|
||||||
|
# do not modify it manually!
|
||||||
|
# ref: https://github.com/ggerganov/llama.cpp/pull/6920
|
||||||
|
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||||
|
# encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that
|
||||||
|
# is specific for the BPE pre-tokenizer used by the model
|
||||||
|
# we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can
|
||||||
|
# use in llama.cpp to implement the same pre-tokenizer
|
||||||
|
|
||||||
|
chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
|
||||||
|
|
||||||
|
chktok = tokenizer.encode(chktxt)
|
||||||
|
chkhsh = sha256(str(chktok).encode()).hexdigest()
|
||||||
|
|
||||||
|
print(f"chktok: {chktok}")
|
||||||
|
print(f"chkhsh: {chkhsh}")
|
||||||
|
|
||||||
|
res = None
|
||||||
|
|
||||||
|
# NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
|
||||||
|
# or pull the latest version of the model from Huggingface
|
||||||
|
# don't edit the hashes manually!
|
||||||
|
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||||
|
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||||
|
res = "llama-bpe"
|
||||||
|
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
|
||||||
|
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
|
||||||
|
res = "deepseek-llm"
|
||||||
|
if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821":
|
||||||
|
# ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base
|
||||||
|
res = "deepseek-coder"
|
||||||
|
if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed":
|
||||||
|
# ref: https://huggingface.co/tiiuae/falcon-7b
|
||||||
|
res = "falcon"
|
||||||
|
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
||||||
|
# ref: https://huggingface.co/BAAI/bge-small-en-v1.5
|
||||||
|
res = "bert-bge"
|
||||||
|
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
||||||
|
# ref: https://huggingface.co/mosaicml/mpt-7b
|
||||||
|
res = "mpt"
|
||||||
|
if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34":
|
||||||
|
# ref: https://huggingface.co/bigcode/starcoder2-3b
|
||||||
|
res = "starcoder"
|
||||||
|
if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454":
|
||||||
|
# ref: https://huggingface.co/openai-community/gpt2
|
||||||
|
res = "gpt-2"
|
||||||
|
|
||||||
|
if res is None:
|
||||||
|
print("\n")
|
||||||
|
print("**************************************************************************************")
|
||||||
|
print("** WARNING: The BPE pre-tokenizer was not recognized!")
|
||||||
|
print("** There are 2 possible reasons for this:")
|
||||||
|
print("** - the model has not been added to convert-hf-to-gguf-update.py yet")
|
||||||
|
print("** - the pre-tokenization config has changed upstream")
|
||||||
|
print("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
|
||||||
|
print("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
|
||||||
|
print("**")
|
||||||
|
print(f"** chkhsh: {chkhsh}")
|
||||||
|
print("**************************************************************************************")
|
||||||
|
print("\n")
|
||||||
|
raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
|
||||||
|
|
||||||
|
print(f"tokenizer.ggml.pre: {res}")
|
||||||
|
print(f"chkhsh: {chkhsh}")
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def _set_vocab_gpt2(self) -> None:
|
def _set_vocab_gpt2(self) -> None:
|
||||||
tokens, toktypes = self.get_basic_vocab()
|
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||||
|
self.gguf_writer.add_tokenizer_pre(tokpre)
|
||||||
self.gguf_writer.add_token_list(tokens)
|
self.gguf_writer.add_token_list(tokens)
|
||||||
self.gguf_writer.add_token_types(toktypes)
|
self.gguf_writer.add_token_types(toktypes)
|
||||||
|
|
||||||
|
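Because `get_vocab_base_pre()` is machine-generated, checking whether a local tokenizer is already covered reduces to recomputing the fingerprint by hand. A small sketch (the tokenizer path is an example; `chktxt` is the exact string from the function above):

```python
# recompute the BPE pre-tokenizer fingerprint for a local tokenizer and
# compare it against the hashes listed in get_vocab_base_pre()
from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("models/tokenizers/llama-bpe")  # example path

chktxt = '\n \n\n \n\n\n \t \t\t \t\n  \n   \n    \n     \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'

chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # "0ef9807a..." would identify the "llama-bpe" pre-tokenizer
```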
@@ -277,6 +348,8 @@ class Model(ABC):
         vocab_size = hparams["vocab_size"]
         assert max(tokenizer.get_vocab().values()) < vocab_size
 
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
         merges = []
         vocab = {}
         mergeable_ranks = tokenizer.mergeable_ranks
@@ -304,6 +377,7 @@ class Model(ABC):
             toktypes.append(gguf.TokenType.NORMAL)
 
         self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
@@ -376,6 +450,7 @@ class Model(ABC):
         assert len(tokens) == vocab_size
 
         self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
@@ -397,6 +472,7 @@ class Model(ABC):
         assert len(tokens) == vocab.vocab_size
 
         self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
@@ -840,6 +916,7 @@ class XverseModel(Model):
             toktypes.append(toktype)
 
         self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
@@ -1335,6 +1412,11 @@ class LlamaModel(Model):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
 
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
     # Same as super class, but permuting q_proj, k_proj
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
@@ -2052,6 +2134,7 @@ class Phi3MiniModel(Model):
                 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
 
         self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
@@ -2294,6 +2377,7 @@ class InternLM2Model(Model):
             toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
 
         self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_tokenizer_pre("default")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
@@ -2443,7 +2527,7 @@ class BertModel(Model):
         self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        tokens, toktypes = self.get_basic_vocab()
+        tokens, toktypes, tokpre = self.get_vocab_base()
         self.vocab_size = len(tokens)
 
         # we need this to validate the size of the token_type embeddings
@@ -2461,6 +2545,7 @@ class BertModel(Model):
 
         # add vocab to gguf
         self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
@@ -2482,6 +2567,10 @@ class BertModel(Model):
                 print(f"Can not map tensor {name!r}")
                 sys.exit()
 
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
             data = data_torch.squeeze().numpy()
             n_dims = len(data.shape)
             new_dtype: type[np.floating[Any]]
@@ -2638,6 +2727,9 @@ class MambaModel(Model):
         field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
         self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
 
+        field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE)
+        self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]))
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
         self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
 
@@ -2843,6 +2935,7 @@ def parse_args() -> argparse.Namespace:
         help="directory containing model file",
     )
     parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
+    parser.add_argument("--model-name", type=str, default=None, help="name of the model")
 
     return parser.parse_args()

@@ -281,6 +281,7 @@ class GGMLToGGUF:
     def add_vocab(self, gguf_writer):
        hp = self.model.hyperparameters
         gguf_writer.add_tokenizer_model('llama')
+        gguf_writer.add_tokenizer_pre('default')
         tokens = []
         scores = []
         toktypes = []

@@ -99,6 +99,7 @@ def main():
 
     tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir)
     gguf_writer.add_tokenizer_model('llama')
+    gguf_writer.add_tokenizer_pre('default')
     gguf_writer.add_token_list(tokens)
     gguf_writer.add_token_scores(scores)
     gguf_writer.add_token_types(toktypes)

@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
         printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
         printf("    example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1 ;
@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
     int n_kv_max = 2048;
     int n_batch = 2048;
     int n_ubatch = 512;
+    bool flash_attn = false;
     int is_pp_shared = 0;
    int n_gpu_layers = 0;
 
@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 6) {
-        is_pp_shared = std::atoi(argv[5]);
+        flash_attn = std::atoi(argv[5]);
     }
 
     if (argc >= 7) {
-        n_gpu_layers = std::atoi(argv[6]);
+        is_pp_shared = std::atoi(argv[6]);
     }
 
     if (argc >= 8) {
-        n_pp = parse_list(argv[7]);
+        n_gpu_layers = std::atoi(argv[7]);
    }
 
     if (argc >= 9) {
-        n_tg = parse_list(argv[8]);
+        n_pp = parse_list(argv[8]);
     }
 
     if (argc >= 10) {
-        n_pl = parse_list(argv[9]);
+        n_tg = parse_list(argv[9]);
+    }
+
+    if (argc >= 11) {
+        n_pl = parse_list(argv[10]);
     }
 
     // init LLM
|
||||||
ctx_params.n_ctx = n_kv_max;
|
ctx_params.n_ctx = n_kv_max;
|
||||||
ctx_params.n_batch = n_batch;
|
ctx_params.n_batch = n_batch;
|
||||||
ctx_params.n_ubatch = n_ubatch;
|
ctx_params.n_ubatch = n_ubatch;
|
||||||
|
ctx_params.flash_attn = flash_attn;
|
||||||
|
|
||||||
ctx_params.n_threads = params.n_threads;
|
ctx_params.n_threads = params.n_threads;
|
||||||
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||||
|
@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
|
LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
|
||||||
LOG_TEE("\n");
|
LOG_TEE("\n");
|
||||||
|
|
||||||
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
|
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
|
||||||
|
|
|
@ -24,6 +24,7 @@ struct Stats {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct StatParams {
|
struct StatParams {
|
||||||
|
std::string dataset;
|
||||||
std::string ofile = "imatrix.dat";
|
std::string ofile = "imatrix.dat";
|
||||||
int n_output_frequency = 10;
|
int n_output_frequency = 10;
|
||||||
int verbosity = 1;
|
int verbosity = 1;
|
||||||
|
@ -47,7 +48,7 @@ private:
|
||||||
std::vector<float> m_src1_data;
|
std::vector<float> m_src1_data;
|
||||||
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
|
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
|
||||||
//
|
//
|
||||||
void save_imatrix(const char * file_name) const;
|
void save_imatrix(const char * file_name, const char * dataset) const;
|
||||||
void keep_imatrix(int ncall) const;
|
void keep_imatrix(int ncall) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -200,7 +201,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
|
||||||
}
|
}
|
||||||
|
|
||||||
void IMatrixCollector::save_imatrix() const {
|
void IMatrixCollector::save_imatrix() const {
|
||||||
save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
|
save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
void IMatrixCollector::keep_imatrix(int ncall) const {
|
void IMatrixCollector::keep_imatrix(int ncall) const {
|
||||||
|
@ -208,14 +209,14 @@ void IMatrixCollector::keep_imatrix(int ncall) const {
|
||||||
if (file_name.empty()) file_name = "imatrix.dat";
|
if (file_name.empty()) file_name = "imatrix.dat";
|
||||||
file_name += ".at_";
|
file_name += ".at_";
|
||||||
file_name += std::to_string(ncall);
|
file_name += std::to_string(ncall);
|
||||||
save_imatrix(file_name.c_str());
|
save_imatrix(file_name.c_str(), m_params.dataset.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
void IMatrixCollector::save_imatrix(const char * fname) const {
|
void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const {
|
||||||
std::ofstream out(fname, std::ios::binary);
|
std::ofstream out(fname, std::ios::binary);
|
||||||
int n_entries = m_stats.size();
|
int n_entries = m_stats.size();
|
||||||
out.write((const char *) &n_entries, sizeof(n_entries));
|
out.write((const char *) &n_entries, sizeof(n_entries));
|
||||||
for (auto& p : m_stats) {
|
for (const auto & p : m_stats) {
|
||||||
int len = p.first.size();
|
int len = p.first.size();
|
||||||
out.write((const char *) &len, sizeof(len));
|
out.write((const char *) &len, sizeof(len));
|
||||||
out.write(p.first.c_str(), len);
|
out.write(p.first.c_str(), len);
|
||||||
|
@ -224,6 +225,15 @@ void IMatrixCollector::save_imatrix(const char * fname) const {
|
||||||
out.write((const char *) &nval, sizeof(nval));
|
out.write((const char *) &nval, sizeof(nval));
|
||||||
if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
|
if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write the number of call the matrix was computed with
|
||||||
|
out.write((const char *) &m_last_call, sizeof(m_last_call));
|
||||||
|
|
||||||
|
// Write the dataset name at the end of the file to later on specify it in quantize
|
||||||
|
int n_dataset = strlen(dataset);
|
||||||
|
out.write((const char *) &n_dataset, sizeof(n_dataset));
|
||||||
|
out.write(dataset, n_dataset);
|
||||||
|
|
||||||
if (m_params.verbosity > 0) {
|
if (m_params.verbosity > 0) {
|
||||||
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname);
|
fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname);
|
||||||
}
|
}
|
||||||
|
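The resulting file layout is the existing entry records followed by a trailer: the chunk count (`m_last_call`) as a 32-bit int, then the dataset-name length and its raw bytes. A reader sketch, assuming the usual entry body of name length, name, call count, value count and little-endian float32 values (the entry interior is not shown in the hunk above):

```python
import struct

def read_imatrix_trailer(path):
    # walk the entry records the way quantize's load_imatrix() does,
    # then read the trailer appended by the new save_imatrix()
    with open(path, "rb") as f:
        (n_entries,) = struct.unpack("<i", f.read(4))
        for _ in range(n_entries):
            (name_len,) = struct.unpack("<i", f.read(4))
            f.read(name_len)                           # tensor name
            (ncall,) = struct.unpack("<i", f.read(4))  # assumption: per-entry call count
            (nval,) = struct.unpack("<i", f.read(4))
            f.seek(4 * nval, 1)                        # skip the float32 values
        trailer = f.read()
    if not trailer:
        return None  # file predates the trailer
    n_chunks, n_dataset = struct.unpack_from("<ii", trailer)
    return n_chunks, trailer[8:8 + n_dataset].decode()
```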
@@ -548,6 +558,29 @@ int main(int argc, char ** argv) {
         }
     }
 
+    gpt_params params;
+    params.n_batch = 512;
+    if (!gpt_params_parse(args.size(), args.data(), params)) {
+        return 1;
+    }
+
+    params.logits_all = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    print_build_info();
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.random_prompt) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    sparams.dataset = params.prompt_file;
     g_collector.set_parameters(std::move(sparams));
 
     if (!combine_files.empty()) {
@@ -586,28 +619,6 @@ int main(int argc, char ** argv) {
         }
     }
 
-    gpt_params params;
-    params.n_batch = 512;
-    if (!gpt_params_parse(args.size(), args.data(), params)) {
-        return 1;
-    }
-
-    params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
-
-    print_build_info();
-
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
-    if (params.random_prompt) {
-        params.prompt = gpt_random_prompt(rng);
-    }
-
     llama_backend_init();
     llama_numa_init(params.numa);
 

@@ -175,6 +175,7 @@ struct cmd_params {
     std::vector<llama_split_mode> split_mode;
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
+    std::vector<bool> flash_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -196,6 +197,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode    */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
+    /* flash_attn    */ {false},
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},
@@ -221,6 +223,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -sm, --split-mode <none|layer|row>  (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
     printf("  -mg, --main-gpu <i>                 (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
+    printf("  -fa, --flash-attn <0|1>             (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split <ts0/ts1/..>    (default: 0)\n");
@@ -394,6 +397,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<bool>(argv[i], split_delim);
+            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -478,6 +488,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.split_mode.empty())   { params.split_mode = cmd_params_defaults.split_mode; }
     if (params.main_gpu.empty())     { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
+    if (params.flash_attn.empty())   { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty())     { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty())   { params.embeddings = cmd_params_defaults.embeddings; }
@@ -499,6 +510,7 @@ struct cmd_params_instance {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -533,6 +545,7 @@ struct cmd_params_instance {
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
+        cparams.flash_attn = flash_attn;
         cparams.embeddings = embeddings;
 
         return cparams;
@@ -555,6 +568,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
+    for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -573,6 +587,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -597,6 +612,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode   = */ sm,
                 /* .main_gpu     = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -634,6 +650,7 @@ struct test {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -658,6 +675,7 @@ struct test {
         split_mode = inst.split_mode;
         main_gpu = inst.main_gpu;
         no_kv_offload = inst.no_kv_offload;
+        flash_attn = inst.flash_attn;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
@@ -732,7 +750,7 @@ struct test {
             "n_batch", "n_ubatch",
             "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload",
+            "main_gpu", "no_kv_offload", "flash_attn",
             "tensor_split", "use_mmap", "embeddings",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
@@ -754,7 +772,7 @@ struct test {
         }
         if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-            field == "use_mmap" || field == "embeddings") {
+            field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -788,7 +806,7 @@ struct test {
             std::to_string(n_batch), std::to_string(n_ubatch),
             std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
-            std::to_string(main_gpu), std::to_string(no_kv_offload),
+            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -956,6 +974,9 @@ struct markdown_printer : public printer {
         if (field == "no_kv_offload") {
             return "nkvo";
         }
+        if (field == "flash_attn") {
+            return "fa";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }
@@ -1002,6 +1023,9 @@ struct markdown_printer : public printer {
         if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
             fields.emplace_back("no_kv_offload");
         }
+        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
+            fields.emplace_back("flash_attn");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
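Like `-nkvo`, the new `-fa` flag goes through `split<bool>(argv[i], split_delim)`, so a comma-separated list such as `llama-bench -fa 0,1` (an illustrative invocation) expands into one test instance per value via the added `for (const auto & fa : params.flash_attn)` loop; the markdown output then shows the abbreviated `fa` column whenever more than one value, or a non-default value, was given.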
@@ -113,11 +113,11 @@ struct llava_context {
 };
 
 static void show_additional_info(int /*argc*/, char ** argv) {
-    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     LOG_TEE("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }
 
-static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
+static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) {
 
     // load and preprocess the image
     llava_image_embed * embed = NULL;
@@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
         }
         params->prompt = remove_image_from_prompt(prompt);
     } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
+        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str());
         if (!embed) {
-            LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str());
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
             return NULL;
         }
     }
@@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     printf("\n");
 }
 
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);
 
@@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) {
         LOG_TEE("%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
 
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -286,15 +289,18 @@ int main(int argc, char ** argv) {
         show_additional_info(argc, argv);
         return 1;
     }
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        LOG_TEE("%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }
 
-    auto image_embed = load_image(ctx_llava, &params);
+    for (auto & image : params.image) {
+        auto ctx_llava = llava_init_context(&params, model);
+
+        auto image_embed = load_image(ctx_llava, &params, image);
     if (!image_embed) {
+        std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
         return 1;
     }
 
@@ -302,8 +308,11 @@ int main(int argc, char ** argv) {
     process_prompt(ctx_llava, image_embed, &params, params.prompt);
 
     llama_print_timings(ctx_llava->ctx_llama);
 
     llava_image_embed_free(image_embed);
+    ctx_llava->model = NULL;
     llava_free(ctx_llava);
+    }
+    llama_free_model(model);
 
     return 0;
 }
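The refactor above splits initialization so the weights load once per process: `llava_init()` now returns only the `llama_model`, while `llava_init_context()` builds the per-run CLIP and `llama_context` state, which is what makes looping over several `--image` arguments cheap. Setting `ctx_llava->model = NULL` before `llava_free(ctx_llava)` appears intended to keep the per-iteration cleanup from touching the shared model, which is released once by `llama_free_model(model)` after the loop.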
@@ -17,11 +17,9 @@ In this case, CLBlast was already installed so the CMake package is referenced in `CMAKE_PREFIX_PATH`.
 ```cmd
 git clone https://github.com/ggerganov/llama.cpp
 cd llama.cpp
-mkdir build
-cd build
-cmake .. -DBUILD_SHARED_LIBS=OFF -DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH=C:/CLBlast/lib/cmake/CLBlast -G "Visual Studio 17 2022" -A x64
-cmake --build . --config Release
-cmake --install . --prefix C:/LlamaCPP
+cmake -B build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH=C:/CLBlast/lib/cmake/CLBlast -G "Visual Studio 17 2022" -A x64
+cmake --build build --config Release
+cmake --install build --prefix C:/LlamaCPP
 ```
 
 ### Build main-cmake-pkg
@@ -29,9 +27,7 @@ cmake --install . --prefix C:/LlamaCPP
 
 ```cmd
 cd ..\examples\main-cmake-pkg
-mkdir build
-cd build
-cmake .. -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/CLBlast/lib/cmake/CLBlast;C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64
-cmake --build . --config Release
-cmake --install . --prefix C:/MyLlamaApp
+cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/CLBlast/lib/cmake/CLBlast;C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64
+cmake --build build --config Release
+cmake --install build --prefix C:/MyLlamaApp
 ```

@@ -66,7 +66,7 @@ main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt
 
 In this section, we cover the most commonly used options for running the `main` program with the LLaMA models:
 
-- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
+- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`; inferred from `--model-url` if set).
 - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf).
 - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
 - `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models.

@@ -325,7 +325,7 @@ int main(int argc, char ** argv) {
             log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size());
 
         // if we will use the cache for the full prompt without reaching the end of the cache, force
-        // reevaluation of the last token token to recalculate the cached logits
+        // reevaluation of the last token to recalculate the cached logits
         if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
             LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
 

@@ -24,7 +24,7 @@
 #endif
 
 struct quantize_stats_params {
-    std::string model = "models/7B/ggml-model-f16.gguf";
+    std::string model = DEFAULT_MODEL_PATH;
     bool verbose = false;
     bool per_layer_stats = false;
     bool print_histogram = false;

@@ -1,6 +1,6 @@
 set(TARGET quantize)
 add_executable(${TARGET} quantize.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_include_directories(${TARGET} PRIVATE ../../common)
 target_compile_features(${TARGET} PRIVATE cxx_std_11)

@@ -9,7 +9,6 @@
 #include <unordered_map>
 #include <fstream>
 #include <cmath>
-#include <algorithm>
 
 struct quant_option {
     std::string name;
@@ -54,6 +53,10 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
 };
 
+static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE      = "quantize.imatrix.file";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET   = "quantize.imatrix.dataset";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS  = "quantize.imatrix.chunks_count";
+
 static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
     std::string ftype_str;
@@ -114,7 +117,7 @@ static void usage(const char * executable) {
     exit(1);
 }
 
-static void load_imatrix(const std::string & imatrix_file, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
@@ -161,18 +164,33 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
             printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str());
         }
     }
-    printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str());
+
+    // latest imatrix version contains the dataset filename at the end of the file
+    int m_last_call = 0;
+    if (in.peek() != EOF) {
+        in.read((char *)&m_last_call, sizeof(m_last_call));
+        int dataset_len;
+        in.read((char *)&dataset_len, sizeof(dataset_len));
+        std::vector<char> dataset_as_vec(dataset_len);
+        in.read(dataset_as_vec.data(), dataset_len);
+        imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end());
+        printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str());
+    }
+    printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call);
+    return m_last_call;
 }
 
-static void prepare_imatrix(const std::string & imatrix_file,
+static int prepare_imatrix(const std::string & imatrix_file,
+        std::string & imatrix_dataset,
         const std::vector<std::string> & included_weights,
         const std::vector<std::string> & excluded_weights,
         std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+    int m_last_call = -1;
     if (!imatrix_file.empty()) {
-        load_imatrix(imatrix_file, imatrix_data);
+        m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data);
     }
     if (imatrix_data.empty()) {
-        return;
+        return m_last_call;
     }
     if (!excluded_weights.empty()) {
         for (auto& name : excluded_weights) {
@@ -198,6 +216,7 @@ static void prepare_imatrix(const std::string & imatrix_file,
     if (!imatrix_data.empty()) {
         printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size()));
     }
+    return m_last_call;
 }
 
 static ggml_type parse_ggml_type(const char * arg) {
@@ -212,43 +231,6 @@ static ggml_type parse_ggml_type(const char * arg) {
     return result;
 }
 
-static bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
-    const char* sep = strchr(data, '=');
-    if (sep == nullptr || sep - data >= 128) {
-        fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
-        return false;
-    }
-    llama_model_kv_override kvo;
-    std::strncpy(kvo.key, data, sep - data);
-    kvo.key[sep - data] = 0;
-    sep++;
-    if (strncmp(sep, "int:", 4) == 0) {
-        sep += 4;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-        kvo.int_value = std::atol(sep);
-    } else if (strncmp(sep, "float:", 6) == 0) {
-        sep += 6;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
-        kvo.float_value = std::atof(sep);
-    } else if (strncmp(sep, "bool:", 5) == 0) {
-        sep += 5;
-        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
-        if (std::strcmp(sep, "true") == 0) {
-            kvo.bool_value = true;
-        } else if (std::strcmp(sep, "false") == 0) {
-            kvo.bool_value = false;
-        } else {
-            fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
-            return false;
-        }
-    } else {
-        fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
-        return false;
-    }
-    overrides.emplace_back(std::move(kvo));
-    return true;
-}
-
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -317,10 +299,43 @@ int main(int argc, char ** argv) {
         usage(argv[0]);
     }
 
+    std::string imatrix_dataset;
     std::unordered_map<std::string, std::vector<float>> imatrix_data;
-    prepare_imatrix(imatrix_file, included_weights, excluded_weights, imatrix_data);
+    int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data);
     if (!imatrix_data.empty()) {
         params.imatrix = &imatrix_data;
+        {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+            strncpy(kvo.val_str, imatrix_file.c_str(), 127);
+            kvo.val_str[127] = '\0';
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+        if (!imatrix_dataset.empty()) {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+            strncpy(kvo.val_str, imatrix_dataset.c_str(), 127);
+            kvo.val_str[127] = '\0';
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+
+        {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+            kvo.val_i64 = imatrix_data.size();
+            kv_overrides.emplace_back(std::move(kvo));
+        }
+
+        if (m_last_call > 0) {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+            kvo.val_i64 = m_last_call;
+            kv_overrides.emplace_back(std::move(kvo));
+        }
     }
     if (!kv_overrides.empty()) {
         kv_overrides.emplace_back();
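The four `quantize.imatrix.*` keys land as ordinary GGUF metadata in the quantized model, so they can be inspected after the fact. A hedged sketch using the `gguf` Python package, reusing the `get_field()`/`parts` access pattern from `MambaModel.set_vocab()` above (the reader construction itself is an assumption):

```python
from gguf import GGUFReader  # assumption: gguf-py is installed

reader = GGUFReader("ggml-model-Q4_K_M.gguf")  # example file name

# the two string-valued keys; the *_count keys hold plain integers
for key in ("quantize.imatrix.file", "quantize.imatrix.dataset"):
    field = reader.get_field(key)
    if field is not None:
        # string data lives in the last part as raw bytes, as in set_vocab() above
        print(key, "=", bytes(field.parts[-1]).decode())
```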
@@ -74,15 +74,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/

- Using `make`:

  ```bash
  make server
  ```

- Using `CMake`:

  ```bash
  cmake -B build
  cmake --build build --config Release -t server
  ```

  Binary is at `./build/bin/server`

## Build with SSL

`server` can also be built with SSL support using OpenSSL 3

@@ -99,10 +102,8 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/

- Using `CMake`:

  ```bash
  cmake -B build -DLLAMA_SERVER_SSL=ON
  cmake --build build --config Release -t server
  ```

## Quick Start

@@ -268,6 +268,7 @@ def start_server_background(args):
    server_args.extend(['--defrag-thold', "0.1"])
    server_args.append('--cont-batching')
    server_args.append('--metrics')
    server_args.append('--flash-attn')
    server_args.extend(['--log-format', "text"])
    args = [str(arg) for arg in [server_path, *server_args]]
    print(f"bench: starting server with: {' '.join(args)}")

@@ -90,7 +90,8 @@ export default function () {
        "model": model,
        "stream": true,
        "seed": 42,
        "max_tokens": max_tokens,
        "stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always emits BOS
    }

    const params = {method: 'POST', body: JSON.stringify(payload)};

@@ -1208,6 +1208,27 @@ struct server_context {
            LOG_VERBOSE("eos token found", {});
        }

        auto n_ctx_train = llama_n_ctx_train(model);
        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
                && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            LOG_WARNING("n_predict is not set and self-context extend is disabled."
                        " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
                { "id_slot",              slot.id },
                { "params.n_predict",     slot.params.n_predict },
                { "slot.n_prompt_tokens", slot.n_prompt_tokens },
                { "slot.n_decoded",       slot.n_decoded },
                { "slot.n_predict",       slot.n_predict },
                { "n_slots",              params.n_parallel },
                { "slot.n_ctx",           slot.n_ctx },
                { "n_ctx",                n_ctx },
                { "n_ctx_train",          n_ctx_train },
                { "ga_n",                 slot.ga_n },
            });
            slot.truncated      = true;
            slot.stopped_limit  = true;
            slot.has_next_token = false; // stop prediction
        }

        LOG_VERBOSE("next token", {
            {"id_slot",   slot.id},
            {"id_task",   slot.id_task},

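As a concrete illustration of the guard added above: with a typical `n_ctx_train` of 4096, an unset `n_predict`, and self-extend disabled, a model that never emits EOS would otherwise generate forever. A self-contained sketch of the stopping condition (illustrative numbers only, not from the diff):

```cpp
#include <cstdio>

int main() {
    const int n_ctx_train     = 4096; // illustrative value
    const int n_prompt_tokens = 10;
    int       n_decoded       = 0;

    for (;;) {
        ++n_decoded; // the model keeps sampling and never emits EOS
        if (n_prompt_tokens + n_decoded >= n_ctx_train) {
            break;   // the new guard: truncated/stopped_limit, stop prediction
        }
    }
    printf("capped after %d generated tokens\n", n_decoded); // prints 4086
    return 0;
}
```
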
@@ -2142,7 +2163,7 @@ struct server_context {
        });

        // process the created batch of tokens
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            for (auto & slot : slots) {

@@ -2333,7 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
    printf("                           disable KV offload\n");
    }
    printf("  -m FNAME, --model FNAME\n");
    printf("                           model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH);
    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
    printf("                           model download url (default: unused)\n");
    printf("  -hfr REPO, --hf-repo REPO\n");

@@ -2357,6 +2378,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
    printf("  --embeddings             enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
    printf("  -np N, --parallel N      number of slots for process requests (default: %d)\n", params.n_parallel);
    printf("  -cb, --cont-batching     enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
    printf("  -fa, --flash-attn        enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
    printf("  -spf FNAME, --system-prompt-file FNAME\n");
    printf("                           set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
    printf("  -ctk TYPE, --cache-type-k TYPE\n");

@@ -2372,7 +2394,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
    printf("  -n, --n-predict          maximum tokens to predict (default: %d)\n", params.n_predict);
    printf("  --override-kv KEY=TYPE:VALUE\n");
    printf("                           advanced option to override model metadata by key. may be specified multiple times.\n");
    printf("                           types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
    printf("  -gan N, --grp-attn-n N   set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n");
    printf("  -gaw N, --grp-attn-w N   set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
    printf("  --chat-template JINJA_TEMPLATE\n");

@@ -2722,6 +2744,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
            params.embedding = true;
        } else if (arg == "-cb" || arg == "--cont-batching") {
            params.cont_batching = true;
        } else if (arg == "-fa" || arg == "--flash-attn") {
            params.flash_attn = true;
        } else if (arg == "-np" || arg == "--parallel") {
            if (++i >= argc) {
                invalid_param = true;

@@ -2803,43 +2827,11 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                invalid_param = true;
                break;
            }
            if (!parse_kv_override(argv[i], params.kv_overrides)) {
                fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
                invalid_param = true;
                break;
            }
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            server_print_usage(argv[0], default_params, default_sparams);

@@ -2847,6 +2839,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
        }
    }

    gpt_params_handle_model_default(params);

    if (!params.kv_overrides.empty()) {
        params.kv_overrides.emplace_back();
        params.kv_overrides.back().key[0] = 0;

@@ -5,7 +5,7 @@ Feature: llama.cpp server

  Background: Server startup
    Given a server listening on localhost:8080
    And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf
    And a model file bert-bge-small.gguf
    And a model alias bert-bge-small
    And 42 as server seed
    And 2 slots

@@ -1784,12 +1784,14 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {

void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
    // reset state for the next run
    if (!sched->is_reset) {
        size_t hash_size = sched->hash_set.size;
        memset(sched->hash_set.keys,      0, sizeof(sched->hash_set.keys[0])      * hash_size); // NOLINT
        memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size);
        memset(sched->tensor_copies,      0, sizeof(sched->tensor_copies[0])      * hash_size);

        sched->is_reset = true;
    }
    sched->is_alloc = false;
}

@@ -16,6 +16,7 @@ static bool g_mul_mat_q = false;
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"

@@ -142,6 +143,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
        info.devices[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
        info.devices[id].smpb = prop.sharedMemPerBlock;
        info.devices[id].nsm  = prop.multiProcessorCount;
    }

    for (int id = 0; id < info.device_count; ++id) {

@@ -2296,6 +2298,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_ARGSORT:
            ggml_cuda_op_argsort(ctx, dst);
            break;
        case GGML_OP_FLASH_ATTN_EXT:
            ggml_cuda_flash_attn_ext(ctx, dst);
            break;
        default:
            return false;
    }

@@ -2570,6 +2575,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_FLASH_ATTN_EXT:
            return true;
        default:
            return false;

@@ -142,6 +142,7 @@
#define CC_PASCAL     600
#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA      700
#define CC_AMPERE     800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
#define CC_RDNA2      (CC_OFFSET_AMD + 1030)

@@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
    return a;
}

static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll

@@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

@@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
    return x;
}

static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}

#if CUDART_VERSION < 12000
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
    return mask_low | mask_high;
}
#endif // CUDART_VERSION < 12000

#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300

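The previously commented-out `half2` reduction above uses the same xor-shuffle butterfly as the `float` variants: after log2(32) = 5 steps, every lane of the warp holds the maximum. A host-side analogue of the access pattern (a sketch, with `__shfl_xor_sync` stood in for by an array lookup):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    float lane[32];
    for (int i = 0; i < 32; ++i) {
        lane[i] = (float)((i*7) % 32); // arbitrary per-lane values, max is 31
    }
    for (int mask = 16; mask > 0; mask >>= 1) {
        float next[32];
        for (int i = 0; i < 32; ++i) {
            // stands in for __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32))
            next[i] = std::max(lane[i], lane[i ^ mask]);
        }
        std::copy(next, next + 32, lane);
    }
    printf("%g\n", lane[0]); // 31: every lane now holds the warp maximum
    return 0;
}
```
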
@@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
}
#endif // defined(GGML_USE_HIPBLAS)

#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL

#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

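These two macros centralize the architecture checks that the new FlashAttention kernels compile against; device code is expected to branch on them with `#if` and fall back to `NO_DEVICE_CODE` otherwise. A sketch of the gating pattern (mirroring the kernels in this diff, not taken from it; assumes common.cuh is in scope):

```cpp
// hypothetical helper, illustrating the #if FP16_AVAILABLE / NO_DEVICE_CODE idiom
static __device__ __forceinline__ half sketch_fp16_path(half a, half b) {
#if FP16_AVAILABLE
    return __hmax(a, b);      // real fp16 arithmetic on Pascal+ / supported AMD
#else
    GGML_UNUSED(a);
    GGML_UNUSED(b);
    NO_DEVICE_CODE;           // traps if ever launched on unsupported hardware
    return __float2half(0.0f);
#endif
}
```
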
@@ -404,6 +415,7 @@ struct ggml_cuda_device_info {

    struct cuda_device_info {
        int     cc;                 // compute capability
        int     nsm;                // number of streaming multiprocessors
        size_t  smpb;               // max. shared memory per block
        bool    vmm;                // virtual memory support
        size_t  vmm_granularity;    // granularity of virtual memory

@@ -5,16 +5,16 @@

template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (i >= k) {
        return;
    }

    const int64_t ib = i/qk;                  // block index
    const int64_t iqs = (i%qk)/qr;            // quant index
    const int64_t iybs = i - i%qk;            // y block start index
    const int64_t y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;

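The point of widening these indices to `int64_t`: products in the style of `blockDim.x*blockIdx.x` overflow 32-bit `int` once a tensor has more than 2^31 elements, which quantized LLM weight tensors can reach. A plain host-side sketch of the failure mode (illustrative launch numbers):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t blockDim_x = 256, blockIdx_x = 5000000, threadIdx_x = 0;
    // mirrors the fixed expression: (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x)
    const int64_t i = 2*(blockDim_x*blockIdx_x + threadIdx_x); // 2,560,000,000
    printf("64-bit index: %lld\n", (long long)i);
    // what the old 32-bit arithmetic would have produced (wraps negative on
    // typical two's-complement targets):
    printf("truncated to 32 bits: %d\n", (int32_t)i);
    return 0;
}
```
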
@@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
#if __CUDA_ARCH__ >= CC_PASCAL
    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

    const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
    const int * x0 = ((int *) vx) + blockIdx.x * nint;
    half2 * y2 = (half2 *) (y + i0);

@@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t ib = 8*i + ir;
    if (ib >= nb32) {
        return;

@@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t ib = 8*i + ir;
    if (ib >= nb32) {
        return;

@@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_q2_K * x = (const block_q2_K *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t n  = tid/32;
    const int64_t l  = tid - 32*n;
    const int64_t is = 8*n + l/16;

    const uint8_t q = x[i].qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

@@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
    const int64_t is = tid/16;  // 0 or 1
    const int64_t il = tid%16;  // 0...15
    const uint8_t q = x[i].qs[il] >> (2*is);
    dst_t * y = yy + i*QK_K + 16*is + il;
    float dall = __low2half(x[i].dm);

@@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_q3_K * x = (const block_q3_K *) vx;

#if QK_K == 256
    const int64_t r   = threadIdx.x/4;
    const int64_t tid = r/2;
    const int64_t is0 = r%2;
    const int64_t l0  = 16*is0 + 4*(threadIdx.x%4);
    const int64_t n   = tid / 4;
    const int64_t j   = tid - 4*n;

    uint8_t m = 1 << (4*n + j);
    int64_t is = 8*n + 2*j + is0;
    int shift = 2*j;

    int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :

@@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t

    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
    const int64_t tid = threadIdx.x;
    const int64_t is  = tid/16;  // 0 or 1
    const int64_t il  = tid%16;  // 0...15
    const int64_t im  = il/8;    // 0...1
    const int64_t in  = il%8;    // 0...7

    dst_t * y = yy + i*QK_K + 16*is + il;

@@ -227,15 +227,15 @@ template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q4_K * x = (const block_q4_K *) vx;

    const int64_t i = blockIdx.x;

#if QK_K == 256
    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;
    const int64_t ir  = tid%8;
    const int64_t is  = 2*il;
    const int64_t n   = 4;

    dst_t * y = yy + i*QK_K + 64*il + n*ir;

@@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
        y[l +32] = d2 * (q[l] >> 4) - m2;
    }
#else
    const int64_t tid = threadIdx.x;
    const uint8_t * q = x[i].qs;
    dst_t * y = yy + i*QK_K;
    const float d = (float)x[i].dm[0];

@@ -268,14 +268,14 @@ template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q5_K * x = (const block_q5_K *) vx;

    const int64_t i = blockIdx.x;

#if QK_K == 256
    // assume 64 threads - this is very slightly better than the one below
    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/16;   // il is in 0...3
    const int64_t ir  = tid%16;   // ir is in 0...15
    const int64_t is  = 2*il;     // is is in 0...6

    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

@@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
    y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
    const int64_t tid = threadIdx.x;
    const uint8_t q = x[i].qs[tid];
    const int64_t im = tid/8;   // 0...3
    const int64_t in = tid%8;   // 0...7
    const int64_t is = tid/16;  // 0 or 1
    const uint8_t h = x[i].qh[in] >> im;
    const float d = x[i].d;
    dst_t * y = yy + i*QK_K + tid;

@@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;

@@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq2_xs * x = (const block_iq2_xs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));

@@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;

@@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t  * q3  = x[i].qs + 8*ib;
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;

@@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * qs = x[i].qs + 8*ib;
    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));

@@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);

@@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int64_t tid = threadIdx.x;
#if QK_K == 256
    const int64_t il = tid/8;  // 0...3
    const int64_t ib = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * sc = (const uint16_t *)x[i].scales;
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;

@@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int64_t i = blockIdx.x;
    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;  // 0...3
    const int64_t ib  = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[ib].qs + 4*il;
    const float d = (float)x[ib].d;

@@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
#if QK_K != 64
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int64_t i = blockIdx.x;
    const block_iq4_xs * x = (const block_iq4_xs *)vx;

    const int64_t tid = threadIdx.x;
    const int64_t il  = tid/8;  // 0...3
    const int64_t ib  = tid%8;  // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);

944 ggml-cuda/fattn.cu Normal file
@@ -0,0 +1,944 @@
#include "common.cuh"
#include "fattn.cuh"

#include <cstdint>

#if FP16_MMA_AVAILABLE
#include <mma.h>
#endif

#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f           // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
static __global__ void flash_attn_vec_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
#if FP16_AVAILABLE
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
    const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic);
    const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
    const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
    const half   * maskh = (const half   *)  mask + ne11*ic;

    const int stride_KV  = nb11 / sizeof(half);
    const int stride_KV2 = nb11 / sizeof(half2);

    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
    __builtin_assume(tid < nwarps*WARP_SIZE);

    __shared__ half KQ[nwarps*WARP_SIZE];
    KQ[tid] = -INFINITY;
    half2 * KQ2 = (half2 *) KQ;

    half kqmax = -HALF_MAX_HALF;
    half kqsum = 0.0f;

    __shared__ half kqmax_shared[WARP_SIZE];
    __shared__ half kqsum_shared[WARP_SIZE];
    if (threadIdx.y == 0) {
        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
        kqsum_shared[threadIdx.x] = 0.0f;
    }
    __syncthreads();

    // Convert Q to half2 and store in registers:
    half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
#pragma unroll
    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
        const int i = i0 + threadIdx.x;
        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
            break;
        }

        Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
    }

    half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.

    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
        // Calculate KQ tile and keep track of new maximum KQ values:
        half kqmax_new = kqmax;
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
            const int i_KQ = i_KQ_0 + threadIdx.y;

            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
                break;
            }

            half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                const int k_KQ = k_KQ_0 + threadIdx.x;
                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
                    break;
                }

                const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
                sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
            }

            sum2 = warp_reduce_sum(sum2);
            half sum = __low2half(sum2) + __high2half(sum2);
            sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
            kqmax_new = __hmax(kqmax_new, sum);
            if (threadIdx.x == 0) {
                KQ[i_KQ] = sum;
            }
        }

        kqmax_new = warp_reduce_max(kqmax_new);
        if (threadIdx.x == 0) {
            kqmax_shared[threadIdx.y] = kqmax_new;
        }
        __syncthreads();
        kqmax_new = kqmax_shared[threadIdx.x];
        kqmax_new = warp_reduce_max(kqmax_new);

        const half KQ_max_scale = hexp(kqmax - kqmax_new);
        kqmax = kqmax_new;

        const half val = hexp(KQ[tid] - kqmax);
        kqsum = kqsum*KQ_max_scale + val;
        KQ[tid] = val;

        VKQ *= __half2half2(KQ_max_scale);

        __syncthreads();

        if (tid < D) {
#pragma unroll
            for (int k0 = 0; k0 < D; k0 += 2) {
                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                    break;
                }

                half2 V_k;
                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
                VKQ += V_k*KQ2[k0/2];
            }
        }

        __syncthreads();
    }

    if (tid >= D) {
        kqsum = 0.0f;
    }

    kqsum = warp_reduce_sum(kqsum);
    if (threadIdx.x == 0) {
        kqsum_shared[threadIdx.y] = kqsum;
    }
    __syncthreads();
    kqsum = kqsum_shared[threadIdx.x];
    kqsum = warp_reduce_sum(kqsum);

    if (tid >= D) {
        return;
    }

    half dst_val = (__low2half(VKQ) + __high2half(VKQ));
    if (parallel_blocks == 1) {
        dst_val /= kqsum;
    }
    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;

    if (parallel_blocks == 1 || tid != 0) {
        return;
    }
    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
#else
    NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}

// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3) {
#if FP16_MMA_AVAILABLE
    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.

    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
    constexpr int frag_m = ncols == 8 ? 32 : 16;
    constexpr int frag_n = ncols == 8 ?  8 : 16;
    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;

    constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");

    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
    constexpr int D_padded = D + 8;
    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);

    const int stride_Q  = nb01 / sizeof(float);
    const int stride_KV = nb11 / sizeof(half);

    frag_b Q_b[D/16][ncols/frag_n];

    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
    constexpr int mem_KQ = ncols*kqs_padded*kqar;
    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
    float * KQ_f = (float *) KQ;
    half2 * KQ2 = (half2 *) KQ;

    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
    float       KQ_max_f[ncols/nwarps];
    float KQ_max_scale_f[ncols/nwarps] = {0.0f};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_f[j] = -FLT_MAX/2.0f;
    }

    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
    half2       KQ_max_h2[ncols/nwarps];
    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};

#pragma unroll
    for (int j = 0; j < ncols/nwarps; ++j) {
        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
    }

    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
    half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
                break;
            }
            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
        }
    }

    // Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            if (i0 + WARP_SIZE > D && i >= D) {
                break;
            }
            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
        }
    }

    __syncthreads();

    // Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
    for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
        }
    }

    __syncthreads();

    // Iterate over ne11 == previous tokens:
    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
        // Calculate tile of KQ:
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
            frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
            }
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                frag_a_K K_a;
                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                }
            }
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
            }
        }

        __syncthreads();

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (std::is_same<KQ_acc_t, float>::value) {
                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
                }

                float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                    const int k = k0 + threadIdx.x;

                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
                }
                KQ_max_new = warp_reduce_max(KQ_max_new);

                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                KQ_max_scale_f[j0/nwarps] = expf(diff);
                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                    KQ_max_scale_f[j0/nwarps] = 0.0f;
                }
                KQ_max_f[j0/nwarps] = KQ_max_new;

                float KQ_rowsum_add = 0.0f;
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                    const int k = k0 + threadIdx.x;

                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
                    }
                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
                }
                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
            } else {
                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||||
|
}
|
||||||
|
|
||||||
|
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||||
|
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
||||||
|
}
|
||||||
|
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||||
|
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
||||||
|
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
||||||
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
|
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
||||||
|
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
||||||
|
|
||||||
|
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||||
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
|
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
||||||
|
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
||||||
|
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
||||||
|
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
||||||
|
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
||||||
|
}
|
||||||
|
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||||
|
|
||||||
|
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
||||||
|
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||||
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||||
|
nvcuda::wmma::load_matrix_sync(
|
||||||
|
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||||
|
KQ + j0*(kqar*kqs_padded) + k,
|
||||||
|
kqar*kqs_padded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
|
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||||
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||||
|
|
||||||
|
frag_a_V v_a;
|
||||||
|
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||||
|
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||||
|
nvcuda::wmma::store_matrix_sync(
|
||||||
|
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||||
|
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||||
|
D_padded, nvcuda::wmma::mem_col_major);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
half2 VKQ_scale;
|
||||||
|
if (std::is_same<KQ_acc_t, float>::value) {
|
||||||
|
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
||||||
|
} else {
|
||||||
|
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < VKQ_ratio; ++l) {
|
||||||
|
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
||||||
|
}
|
||||||
|
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j_VKQ = j0 + threadIdx.y;
|
||||||
|
if (ic0 + j_VKQ >= ne01) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||||
|
|
||||||
|
float KQ_rowsum_j;
|
||||||
|
if (std::is_same<KQ_acc_t, float>::value) {
|
||||||
|
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
|
||||||
|
} else {
|
||||||
|
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
if (i0 + WARP_SIZE > D && i >= D) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||||
|
if (parallel_blocks == 1) {
|
||||||
|
dst_val /= KQ_rowsum_j;
|
||||||
|
}
|
||||||
|
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
float2 dst_meta_val;
|
||||||
|
if (std::is_same<KQ_acc_t, float>::value) {
|
||||||
|
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
||||||
|
} else {
|
||||||
|
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||||
|
}
|
||||||
|
dst_meta_val.y = KQ_rowsum_j;
|
||||||
|
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // FP16_MMA_AVAILABLE
|
||||||
|
}
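
For reference, the bookkeeping above is the standard streaming ("online") softmax. Written out for one KQ column, with s_{t,k} the masked KQ scores of KV tile t, the update the kernel performs is (a sketch of the math only, not part of the diff):

    m_t    = \max\bigl(m_{t-1},\ \max_k s_{t,k}\bigr)                      % KQ_max
    \ell_t = e^{m_{t-1}-m_t}\,\ell_{t-1} + \sum_k e^{s_{t,k}-m_t}          % KQ_rowsum, rescaled by KQ_max_scale
    o_t    = e^{m_{t-1}-m_t}\,o_{t-1}   + \sum_k e^{s_{t,k}-m_t}\,V_k      % VKQ accumulator

The final result is o_T / \ell_T; exponentials whose argument is at or below SOFTMAX_FTZ_THRESHOLD are flushed to zero, which is what the ftz_mask logic implements.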

template<int D, int parallel_blocks> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
#if FP16_AVAILABLE
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    __shared__ float2 meta[parallel_blocks];
    if (tid < 2*parallel_blocks) {
        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
    }

    __syncthreads();

    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        const float KQ_max_scale = expf(diff);
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
#else
    NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
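
Each of the parallel_blocks partial launches writes an unnormalized accumulator o_l plus its row maximum m_l and row sum \ell_l (the float2 stored in dst_meta). The kernel above merges them as (again only a sketch of the math):

    m = \max_l m_l, \qquad
    \mathrm{dst} = \frac{\sum_l e^{m_l - m}\, o_l}{\sum_l e^{m_l - m}\, \ell_l}

which is exactly the pair accumulated in VKQ_numerator and VKQ_denominator.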

constexpr int get_max_power_of_2(int x) {
    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}

static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");

// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}

static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");

template <int D, int parallel_blocks> void launch_fattn_vec_f16(
        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
        ggml_cuda_pool & pool, cudaStream_t main_stream
) {
    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

    if (parallel_blocks > 1) {
        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
    }

    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
    const dim3 block_dim(WARP_SIZE, nwarps, 1);
    const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
    const int  shmem = 0;

    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));

    flash_attn_vec_ext_f16<D, parallel_blocks>
        <<<blocks_num, block_dim, shmem, main_stream>>> (
                (const char *) Q->data,
                (const char *) K->data,
                (const char *) V->data,
                mask ? ((const char *) mask->data) : nullptr,
                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                scale,
                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                Q->nb[1], Q->nb[2], Q->nb[3],
                K->nb[1], K->nb[2], K->nb[3],
                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
        );
    CUDA_CHECK(cudaGetLastError());

    if (parallel_blocks == 1) {
        return;
    }

    const dim3 block_dim_combine(D, 1, 1);
    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
    const int  shmem_combine = 0;

    flash_attn_combine_results<D, parallel_blocks>
        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
    CUDA_CHECK(cudaGetLastError());
}
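
As an illustration of the launch geometry above, with hypothetical decode-time shapes (numbers for illustration only, not taken from the diff):

    // Hypothetical example: D = 128, parallel_blocks = 4, single-token decode.
    // Q->ne = {128, 1, 32, 1}  (head size, n_queries, n_heads, n_batch)
    // nwarps     = (128 + 32 - 1) / 32 = 4  ->  block_dim  = (32, 4, 1)
    // blocks_num = (4*1, 32, 1)             ->  128 thread blocks, each scanning 1/4 of the KV cache
    // dst_tmp then holds 4 partial KQV tensors that flash_attn_combine_results reduces.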

template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
        ggml_cuda_pool & pool, cudaStream_t main_stream
) {
    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

    if (parallel_blocks > 1) {
        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
    }

    constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
    const dim3 block_dim(WARP_SIZE, nwarps, 1);
    const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
    const int  shmem = 0;

    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));

    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
        <<<blocks_num, block_dim, shmem, main_stream>>> (
                (const char *) Q->data,
                (const char *) K->data,
                (const char *) V->data,
                mask ? ((const char *) mask->data) : nullptr,
                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
                scale,
                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
                Q->nb[1], Q->nb[2], Q->nb[3],
                K->nb[1], K->nb[2], K->nb[3],
                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
        );
    CUDA_CHECK(cudaGetLastError());

    if ((parallel_blocks) == 1) {
        return;
    }

    const dim3 block_dim_combine(D, 1, 1);
    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
    const int  shmem_combine = 0;

    flash_attn_combine_results<D, parallel_blocks>
        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
    CUDA_CHECK(cudaGetLastError());
}

template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
) {
    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];

    if (4*blocks_num_pb1 < 2*nsm) {
        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
        return;
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
        return;
    }
    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
}
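
The cascade above trades KV-cache splitting against natural parallelism: only when the grid would leave most SMs idle is the KV cache split across 4 or 2 blocks, at the cost of the extra combine pass. A worked example with hypothetical figures (illustrative only; an A100 has nsm = 108):

    // Small batch: cols_per_block = 16, Q->ne[1] = 1, 32 heads, batch 1
    //   blocks_num_pb1 = ceil(1/16) * 32 * 1 = 32
    //   4*32 = 128 < 2*108 = 216  -> parallel_blocks = 4
    // Long prompt: Q->ne[1] = 512, otherwise as above
    //   blocks_num_pb1 = ceil(512/16) * 32 = 1024
    //   neither condition holds   -> parallel_blocks = 1 (grid already saturates the GPU)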

void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    const ggml_tensor * mask = dst->src[3];

    ggml_tensor * KQV = dst;

    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(K->type == GGML_TYPE_F16);
    GGML_ASSERT(V->type == GGML_TYPE_F16);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

    ggml_cuda_set_device(ctx.device);

    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

    const int32_t precision = KQV->op_params[1];

    if (precision != GGML_PREC_DEFAULT) {
        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
            constexpr int cols_per_block = 16;
            constexpr int nwarps         =  4;
            switch (Q->ne[0]) {
                case 64:
                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 80:
                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 96:
                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 112:
                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 128:
                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 256:
                    launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        } else {
            constexpr int cols_per_block = 32;
            constexpr int nwarps         =  4;
            switch (Q->ne[0]) {
                case 64:
                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 80:
                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 96:
                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 112:
                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                case 128:
                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                    break;
                // case 256:
                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                //     break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        }
        return;
    }

    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        constexpr int parallel_blocks = 4;
        switch (Q->ne[0]) {
            case 64:
                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                break;
            case 128:
                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                break;
            case 256:
                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        constexpr int cols_per_block = 8;
        constexpr int nwarps         = 4;
        switch (Q->ne[0]) {
            case 64:
                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 96:
                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 128:
                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 256:
                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 16;
        constexpr int nwarps         =  4;
        switch (Q->ne[0]) {
            case 64:
                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 80:
                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 96:
                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 112:
                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 128:
                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            case 256:
                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    constexpr int cols_per_block = 32;
    constexpr int nwarps         =  4;
    switch (Q->ne[0]) {
        case 64:
            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        case 80:
            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        case 96:
            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        case 112:
            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        case 128:
            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        case 256:
            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
    return;
}
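
To make the dispatch concrete, a few hypothetical shapes and the path each takes, with Q->ne = {head size, n_queries, n_heads, n_batch} (illustrative only):

    // Q->ne = {128,   1, 32, 1} -> launch_fattn_vec_f16<128, 4>        (single-token decode)
    // Q->ne = {128,   8, 32, 1} -> launch_fattn_f16<128,  8, 4, half>  (small batch)
    // Q->ne = {128,  32, 32, 1} -> launch_fattn_f16<128, 16, 4, half>
    // Q->ne = {128, 512, 32, 1} -> launch_fattn_f16<128, 32, 4, half>  (prompt processing)
    // with GGML_PREC_F32 requested, the same shapes go to the float-accumulator variants.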
3  ggml-cuda/fattn.cuh  Normal file
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,7 +1,17 @@
 #include "softmax.cuh"

-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid = threadIdx.x;
@@ -28,7 +38,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;

     float max_val = -INFINITY;

@@ -40,10 +50,10 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             break;
         }

-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
@@ -109,12 +119,13 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
             return;
         }

-        const int idst = rowx*ncols + col;
+        const int64_t idst = (int64_t)rowx*ncols + col;
         dst[idst] = vals[col] * inv_sum;
     }
 }

-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];

     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;

     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

     // positions tensor
-    float * src2_dd = nullptr;
+    void * src2_d = nullptr;

-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_d = (void *)src2->data;
     }

-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
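
With the templated t2f32 conversion, the value each thread accumulates before the max/sum reductions is independent of whether the mask and positions are stored as F16 or F32 (a sketch of the math only):

    z_{r,c} = s \cdot x_{r,c} + m_{r,c} + \text{slope} \cdot p_c, \qquad
    \mathrm{dst}_{r,c} = \frac{e^{z_{r,c} - \max_c z_{r,c}}}{\sum_c e^{z_{r,c} - \max_c z_{r,c}}}

where s is the scale, m the optional mask and p the optional ALiBi positions; the int64_t index casts guard against overflow when rowx*ncols exceeds INT_MAX.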
@@ -313,7 +313,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)

 #endif // defined(__ARM_NEON)

-#if defined(__ARM_NEON) && !defined(__MSC_VER)
+#if defined(__ARM_NEON) && !defined(_MSC_VER)

 #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
 #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);

@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     {
                         float scale;
                         memcpy(&scale, dst->op_params, sizeof(float));

+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                        GGML_ASSERT(src2 == nullptr);
+
                         ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                     } break;
                 case GGML_OP_DIAG_MASK_INF:
268
ggml-metal.m
268
ggml-metal.m
|
@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU,
|
GGML_METAL_KERNEL_TYPE_SILU,
|
||||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
||||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
||||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
||||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
||||||
|
@ -177,6 +179,14 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||||
|
@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
||||||
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
||||||
(int) kernel->pipeline.threadExecutionWidth); \
|
(int) kernel->pipeline.threadExecutionWidth); \
|
||||||
*/
|
*/
|
||||||
|
@ -459,7 +469,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
return NULL; \
|
return NULL; \
|
||||||
} \
|
} \
|
||||||
} else { \
|
} else { \
|
||||||
GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
|
GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
||||||
}
|
}
|
||||||
|
|
||||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||||
|
@ -481,8 +491,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
||||||
|
@ -612,6 +624,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
|
@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
if (ne00%4 == 0) {
|
if (ne00%4 == 0) {
|
||||||
while (nth < ne00/4 && nth < 256) {
|
while (nth < ne00/4 && nth < 256) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
if (use_f16) {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
while (nth < ne00 && nth < 1024) {
|
while (nth < ne00 && nth < 1024) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
}
|
}
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
if (use_f16) {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
||||||
|
} else {
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float scale;
|
float scale;
|
||||||
|
@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(src1, src2));
|
||||||
|
GGML_ASSERT(src3);
|
||||||
|
|
||||||
|
size_t offs_src3 = 0;
|
||||||
|
|
||||||
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
||||||
|
|
||||||
|
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
||||||
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||||
|
|
||||||
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
||||||
|
const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
||||||
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
||||||
|
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
||||||
|
|
||||||
|
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
|
||||||
|
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
|
||||||
|
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
|
||||||
|
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
|
||||||
|
|
||||||
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
memcpy(&scale, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
bool use_vec_kernel = false;
|
||||||
|
|
||||||
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||||
|
switch (ne00) {
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||||
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||||
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ASSERT(false && "add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
use_vec_kernel = true;
|
||||||
|
|
||||||
|
switch (ne00) {
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||||
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ASSERT(false && "add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||||
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
|
||||||
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
|
||||||
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
|
||||||
|
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
|
||||||
|
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
|
||||||
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
|
||||||
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
|
||||||
|
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
|
||||||
|
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
|
||||||
|
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
|
||||||
|
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
|
||||||
|
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
|
||||||
|
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
|
||||||
|
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
|
||||||
|
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
|
||||||
|
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
|
||||||
|
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
|
||||||
|
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
|
||||||
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
|
||||||
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
|
||||||
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
|
if (!use_vec_kernel) {
|
||||||
|
// half8x8 kernel
|
||||||
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
|
GGML_ASSERT(nqptg <= 32);
|
||||||
|
GGML_ASSERT(nqptg % 8 == 0);
|
||||||
|
GGML_ASSERT(ncpsg % 32 == 0);
|
||||||
|
|
||||||
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||||
|
if (smem > ctx->device.maxThreadgroupMemoryLength) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nsgmax *= 2;
|
||||||
|
}
|
||||||
|
nsgmax /= 2;
|
||||||
|
|
||||||
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
|
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||||
|
|
||||||
|
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||||
|
|
||||||
|
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||||
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
|
|
||||||
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
|
} else {
|
||||||
|
// half1x4 kernel
|
||||||
|
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
|
GGML_ASSERT(nqptg <= 32);
|
||||||
|
GGML_ASSERT(nqptg % 1 == 0);
|
||||||
|
GGML_ASSERT(ncpsg % 32 == 0);
|
||||||
|
|
||||||
|
// simdgroups per threadgroup (a.k.a. warps)
|
||||||
|
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||||
|
|
||||||
|
int64_t nsg = 1;
|
||||||
|
while (nsg <= nsgt) {
|
||||||
|
nsg *= 2;
|
||||||
|
}
|
||||||
|
nsg /= 2;
|
||||||
|
|
||||||
|
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||||
|
|
||||||
|
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||||
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||||
|
|
||||||
|
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    }
+                } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -2590,6 +2779,45 @@ static enum ggml_status ggml_metal_graph_compute(
         MTLCommandBufferStatus status = [command_buffer status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            if (status == MTLCommandBufferStatusError) {
+                MTLCommandBufferError error_code = [command_buffer error].code;
+                switch (error_code) {
+                    case MTLCommandBufferErrorNone:
+                        GGML_METAL_LOG_INFO("no error code reported\n");
+                        break;
+                    case MTLCommandBufferErrorTimeout:
+                        GGML_METAL_LOG_INFO("timeout\n");
+                        break;
+                    case MTLCommandBufferErrorPageFault:
+                        GGML_METAL_LOG_INFO("unserviceable page fault\n");
+                        break;
+                    case MTLCommandBufferErrorOutOfMemory:
+                        GGML_METAL_LOG_INFO("out of memory\n");
+                        break;
+                    case MTLCommandBufferErrorInvalidResource:
+                        GGML_METAL_LOG_INFO("invalid reference to resource\n");
+                        break;
+                    case MTLCommandBufferErrorMemoryless:
+                        GGML_METAL_LOG_INFO("GPU ran out of one or more of its internal resources that support memoryless render pass attachments\n");
+                        break;
+                    case MTLCommandBufferErrorDeviceRemoved:
+                        GGML_METAL_LOG_INFO("device removed\n");
+                        break;
+                    case MTLCommandBufferErrorStackOverflow:
+                        GGML_METAL_LOG_INFO("kernel function of tile shader used too many stack frames\n");
+                        break;
+                    case MTLCommandBufferErrorAccessRevoked:
+                        GGML_METAL_LOG_INFO("access to device revoked by system\n");
+                        break;
+                    case MTLCommandBufferErrorInternal:
+                        GGML_METAL_LOG_INFO("internal error\n");
+                        break;
+                    default:
+                        GGML_METAL_LOG_INFO("unknown error %lu\n", error_code);
+                        break;
+                }
+            }
+
             return GGML_STATUS_FAILED;
         }
     }
@@ -2706,10 +2934,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
     UNUSED(buft);
 }

-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
                 device.currentAllocatedSize / 1024.0 / 1024.0,
                 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

@@ -2719,10 +2950,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
             GGML_METAL_LOG_INFO("\n");
         }
     } else {
-        GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0);
     }
+#endif
 #endif
     UNUSED(device);
+    UNUSED(size_aligned);
 }

 GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -2756,8 +2992,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
         return NULL;
     }

-    GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-    ggml_backend_metal_log_allocated_size(device);
+    //ggml_backend_metal_log_allocated_size(device, size_aligned);

     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
@@ -2844,7 +3079,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
             return false;
         }

-        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+        ggml_backend_metal_log_allocated_size(device, size_aligned);

         ++ctx->n_buffers;
     } else {
@@ -2867,7 +3102,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
                 return false;
             }

-            GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }
@@ -2876,8 +3112,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
         }
     }

-    ggml_backend_metal_log_allocated_size(device);
-
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
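Reviewer note: the buffer-allocation logging above is consolidated so that every call site reports the aligned size through one helper. A minimal self-contained C sketch of the pattern follows; the real helper takes an id<MTLDevice> and reads the totals from the device, so the `current` and `budget` arguments here are stand-ins (assumptions) for device.currentAllocatedSize and device.recommendedMaxWorkingSetSize.

#include <stdio.h>
#include <stddef.h>

// sketch of the refactor: one helper owns the per-buffer log format, so all
// allocation paths print the same line; totals are passed in as placeholders
static void log_allocated_size(size_t size_aligned, size_t current, size_t budget) {
    fprintf(stderr, "%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
            __func__,
            size_aligned / 1024.0 / 1024.0,
            current      / 1024.0 / 1024.0,
            budget       / 1024.0 / 1024.0);
}

int main(void) {
    log_allocated_size((size_t) 3 << 20, (size_t) 128 << 20, (size_t) 1024 << 20);
    return 0;
}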
ggml-metal.metal (672 changed lines)

@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }

+template<typename T>
 kernel void kernel_soft_max(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
-    device const float * ppos  = src2 != src0 ? src2            : nullptr;
-    device       float * pdst  = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const     T * ppos  = src2 != src0 ? (device const T *) src2            : nullptr;
+    device       float * pdst  = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

     float slope = 0.0f;

@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
     }
 }

+template<typename T>
 kernel void kernel_soft_max_4(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device const  char * src2,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2)            : nullptr;
-    device       float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const      T * ppos  = src2 != src0 ? (device const T *) src2              : nullptr;
+    device       float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

     float slope = 0.0f;

@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }

+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
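Reviewer note: after templating, both soft-max kernels compute softmax over src0*scale + mask + slope*pos, with the usual row-max subtraction for numerical stability; only the mask/pos element type (half or float) varies. A scalar C reference of the same per-row math, for checking the kernels against (names are illustrative, not part of the codebase):

#include <math.h>

// dst = softmax(src0*scale + mask + slope*pos) over one row of n elements;
// mask and pos may be NULL, matching the optional src1/src2 in the kernels
static void soft_max_row_ref(int n, float * dst, const float * src0,
                             const float * mask, const float * pos,
                             float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        const float v = src0[i]*scale + (mask ? mask[i] : 0.0f) + (pos ? slope*pos[i] : 0.0f);
        if (v > max_val) {
            max_val = v; // subtracting the row max keeps expf() from overflowing
        }
    }

    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float v = src0[i]*scale + (mask ? mask[i] : 0.0f) + (pos ? slope*pos[i] : 0.0f);
        dst[i] = expf(v - max_val);
        sum += dst[i];
    }

    for (int i = 0; i < n; ++i) {
        dst[i] /= sum;
    }
}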
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }

+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2;          // shared memory size per query in (float)
+    const short T4 = T/4;          // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF +         0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF +         1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j        ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C,         TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+    //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
+    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
+    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
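Reviewer note: both kernels above are built around the online-softmax recurrence from the referenced Flash-Attention paper: as each block of scores arrives, the running max M and running sum S are updated, and the previously accumulated output is rescaled by exp(M_old - M_new). A one-query scalar C sketch of that recurrence, useful for verifying the kernels against (names and shapes are illustrative only):

#include <math.h>
#include <float.h>

// streaming softmax-weighted accumulation for one query row:
// o = sum_j softmax(s)_j * v_j, computed without materializing softmax(s)
static void flash_attn_row_ref(int n_kv, int d,
                               const float * s,   // attention scores, length n_kv
                               const float * v,   // values, n_kv x d, row-major
                               float * o) {       // output, length d
    float M = -FLT_MAX/2; // running max of the scores seen so far
    float S = 0.0f;       // running sum of exp(s - M)

    for (int i = 0; i < d; ++i) {
        o[i] = 0.0f;
    }

    for (int j = 0; j < n_kv; ++j) {
        const float M_new = s[j] > M ? s[j] : M;
        const float ms = expf(M - M_new);    // rescale factor for the previous state
        const float vs = expf(s[j] - M_new); // weight of the new entry

        S = S*ms + vs;
        for (int i = 0; i < d; ++i) {
            o[i] = o[i]*ms + vs*v[j*d + i];
        }

        M = M_new;
    }

    for (int i = 0; i < d; ++i) {
        o[i] /= S; // the kernels' final 1/S rescale
    }
}

The kernels apply the same update blockwise (per C columns of the score matrix) and then merge the per-simdgroup (M, S, O) triples with one more application of the identical rescale rule.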
ggml-quants.c (284 changed lines)

@@ -12384,3 +12384,287 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k)
     block_iq2_s * restrict y = vy;
     quantize_row_iq2_s_reference(x, y, k);
 }
+
+static bool validate_float(float f, size_t i) {
+    if (isinf(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+static bool isinf_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static bool validate_fp16(ggml_fp16_t f, size_t i) {
+    if (isinf_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
+            return false; \
+        } \
+    }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+    if (type < 0 || type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+        return false;
+    }
+
+    if (nbytes % ggml_type_size(type) != 0) {
+        fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
+        return false;
+    }
+
+    const size_t nb = nbytes/ggml_type_size(type);
+
+    switch (type) {
+        case GGML_TYPE_F16:
+            {
+                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 15 < nb; i += 16) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 16; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 7 < nb; i += 8) {
+                    uint16x8_t v = vld1q_u16(f + i);
+                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
+                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_fp16(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                const float * f = (const float *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 7 < nb; i += 8) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 3 < nb; i += 4) {
+                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
+                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
+                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 4; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F64:
+            {
+                const double * f = (const double *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+            } break;
+        case GGML_TYPE_Q2_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+#ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]);
+#else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+#endif
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+#ifdef GGML_QKK_64
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb);
+#else
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+#endif
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+            } break;
+        case GGML_TYPE_Q8_K:
+            {
+                const block_q8_K * q = (const block_q8_K *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(q[i].d, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_M:
+            {
+                const block_iq1_m * q = (const block_iq1_m *) data;
+                for (size_t i = 0; i < nb; ++i) {
+#if QK_K == 64
+                    if (!validate_fp16(q[i].d, i)) {
+                        return false;
+                    }
+#else
+                    iq1m_scale_t scale;
+                    const uint16_t * sc = (const uint16_t *)q[i].scales;
+                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                    if (!validate_fp16(scale.f16, i)) {
+                        return false;
+                    }
+#endif
+                }
+            } break;
+        case GGML_TYPE_IQ2_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+            } break;
+
+        case GGML_TYPE_IQ3_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_XS:
+#if QK_K != 64
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+            } break;
+#endif
+        // with QK_K == 64, iq4_xs is iq4_nl
+        case GGML_TYPE_IQ4_NL:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+            } break;
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+            // nothing to validate
+            break;
+        default:
+            {
+                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+                return false;
+            }
+    }
+
+    return true;
+}
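Reviewer note: the new validators classify fp16 values directly on the bit pattern: all exponent bits set (mask 0x7c00) means infinity when the mantissa (mask 0x03ff) is zero, and NaN otherwise. A standalone C check of that encoding, using uint16_t in place of ggml_fp16_t so it compiles outside the tree:

#include <stdint.h>
#include <stdio.h>

// fp16 layout: 1 sign bit, 5 exponent bits (0x7c00), 10 mantissa bits (0x03ff)
static int is_inf_fp16(uint16_t f) { return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; }
static int is_nan_fp16(uint16_t f) { return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; }

int main(void) {
    printf("%d\n", is_inf_fp16(0x7c00)); // +inf      -> 1
    printf("%d\n", is_inf_fp16(0xfc00)); // -inf      -> 1
    printf("%d\n", is_nan_fp16(0x7e00)); // quiet NaN -> 1
    printf("%d\n", is_inf_fp16(0x3c00)); // 1.0       -> 0
    return 0;
}

The SIMD fast paths in ggml_validate_row_data test only the exponent mask in bulk and fall back to the scalar validators to report the exact offending index, which is why the rescan after a non-zero mask must find a failure (hence the GGML_UNREACHABLE()).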
ggml-sycl.cpp

@@ -13416,11 +13416,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) {
     version += std::to_string(prop.get_minor_version());

     device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
+    std::string name = std::string(prop.get_name());
+    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
+    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");

-    fprintf(stderr, "|%2d|%18s|%45s|%10s|%11d|%8d|%7d|%15lu|\n", id, device_type.c_str(),
-            prop.get_name(), version.c_str(), prop.get_max_compute_units(),
+    auto global_mem_size = prop.get_global_mem_size()/1000000;
+
+    fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+            name.c_str(), version.c_str(), prop.get_max_compute_units(),
             prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-            prop.get_global_mem_size());
+            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }

 void ggml_backend_sycl_print_sycl_devices() {
@@ -13428,9 +13433,10 @@ void ggml_backend_sycl_print_sycl_devices() {
     int device_count = dpct::dev_mgr::instance().device_count();
     std::map<std::string, size_t> DeviceNums;
     fprintf(stderr, "found %d SYCL devices:\n", device_count);
-    fprintf(stderr, "|  |                  |                                             |Compute   |Max compute|Max work|Max sub|               |\n");
-    fprintf(stderr, "|ID|       Device Type|                                         Name|capability|units      |group   |group  |Global mem size|\n");
-    fprintf(stderr, "|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|\n");
+    fprintf(stderr, "|  |                   |                                       |       |Max    |        |Max  |Global |                     |\n");
+    fprintf(stderr, "|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |\n");
+    fprintf(stderr, "|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|\n");
+    fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
     for (int id = 0; id < device_count; ++id) {
         sycl::device device = dpct::dev_mgr::instance().get_device(id);
         sycl::backend backend = device.get_backend();
@@ -14738,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14754,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;

-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
ggml-vulkan.cpp

@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
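Reviewer note: both the SYCL and Vulkan backends mark the missing F16 mask/positions support with #pragma message, which makes the compiler echo the TODO in the build log every time the file compiles, instead of burying it in a comment. A minimal C illustration (message text is a placeholder):

// emitted to the build log at compile time; unlike a plain comment it cannot
// silently go stale without being seen by whoever builds the backend
#pragma message("TODO: wire up F16 mask/positions for this backend")

int main(void) { return 0; }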
ggml.c (418 changed lines)

@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
 #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
 #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
 #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
 #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
 #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
 #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
 #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
 #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do { \

 // unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
 // so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
 #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

 #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do { \

 #if defined(__F16C__)
 // the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
 #endif
 }

+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
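Reviewer note: ggml_vec_mad_f16 follows ggml's usual SIMD shape: process np = n & ~(STEP - 1) elements in vector-width blocks, then finish the tail with scalar code. The same structure in plain C, with the vector step reduced to a fixed-size inner loop (STEP stands in for GGML_F16_STEP; the function name is illustrative):

#include <stddef.h>

#define STEP 8 // stand-in for the platform vector step (GGML_F16_STEP)

// y[i] += x[i]*v, structured like the SIMD helper above: a blocked main loop
// over multiples of STEP, then a scalar loop for the leftover elements
static void vec_mad_blocked(int n, float * y, const float * x, float v) {
    const int np = n & ~(STEP - 1); // largest multiple of STEP not exceeding n

    for (int i = 0; i < np; i += STEP) {
        for (int j = 0; j < STEP; ++j) { // one vector FMA in the real code
            y[i + j] += x[i + j]*v;
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] += x[i]*v;
    }
}

In the fp16 version the leftover loop has to round-trip through f32 (GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16) because scalar C has no native half type.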
@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }

+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "LEAKY_RELU",

     "FLASH_ATTN",
+    "FLASH_ATTN_EXT",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "leaky_relu(x)",

     "flash_attn(x)",
+    "flash_attn_ext(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4560,6 +4622,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;

     ggml_set_op_params_i32(a, 0, prec_i32);
@ -5398,17 +5462,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
GGML_ASSERT(ggml_is_contiguous(a));
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
|
|
||||||
if (mask) {
|
if (mask) {
|
||||||
|
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(ggml_is_matrix(mask));
|
GGML_ASSERT(ggml_is_matrix(mask));
|
||||||
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
GGML_ASSERT(mask->ne[0] == a->ne[0]);
|
||||||
|
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pos) {
|
if (pos) {
|
||||||
GGML_ASSERT(ggml_is_vector(pos));
|
GGML_ASSERT(ggml_is_vector(pos));
|
||||||
GGML_ASSERT(pos->type == GGML_TYPE_F32);
|
GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pos && mask) {
|
||||||
|
GGML_ASSERT(pos->type == mask->type);
|
||||||
|
}
|
||||||
|
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
GGML_ASSERT(pos);
|
GGML_ASSERT(pos);
|
||||||
}
|
}
|
||||||
@@ -6217,6 +6287,59 @@ struct ggml_tensor * ggml_flash_attn(
     return result;
 }

+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        is_node = true;
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
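Worth spelling out what the mask-padding assert above demands of callers. A quick illustration, assuming the standard ggml definition of GGML_PAD (it lives in ggml.h, outside this hunk):

```c
// assuming ggml.h's power-of-two padding macro:
//   #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
// with GGML_KQ_MASK_PAD == 32, the mask's second dimension must satisfy:
//   GGML_PAD(1, 32)  == 32   // a single query row still needs a 32-row mask
//   GGML_PAD(33, 32) == 64   // 33 query rows need a 64-row mask
```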
 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
@@ -12256,7 +12379,7 @@ static void ggml_compute_forward_soft_max_f32(

     GGML_TENSOR_UNARY_OP_LOCALS

-    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;

     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12279,19 +12402,31 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    float * pos = src2 ? (float *) src2->data : src0->data;
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);

     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);

         // broadcast the mask across rows
-        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;

         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (mp) {
-            ggml_vec_acc_f32(nc, wp, mp);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
+            }
         }

         // ALiBi bias
@@ -12299,8 +12434,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);

-            for (int i = 0; i < nc; i++) {
-                wp[i] = wp[i] + slope*pos[i];
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*pos_f32[i];
+                }
             }
         }
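For reference, the slope formula in this hunk is the standard ALiBi head schedule. A self-contained sketch, assuming m0, m1 and n_head_log2 are derived from max_bias earlier in the function (outside this hunk), with n_head_log2 the largest power of two not exceeding n_head:

```c
#include <math.h>
#include <stdint.h>

// per-head ALiBi slope, as computed in the loop above
static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}
// e.g. n_head = 8 and max_bias = 8.0f give n_head_log2 = 8 and m0 = 0.5,
// so heads 0..7 get slopes 1/2, 1/4, ..., 1/256
```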
@@ -14570,6 +14711,198 @@ static void ggml_compute_forward_flash_attn(
     }
 }

+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_TYPE_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
+            ggml_vec_dot_f16(D,
+                    &s, 0,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    Q16, 0, 1);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
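The inner loop above is the streaming ("online") softmax from the referenced paper, fused with the V accumulation. A minimal scalar sketch of the same recurrence, stripped of the ggml types and F16 conversions:

```c
#include <math.h>

// S = sum(exp(s_i - M)), M = running max, V = sum(v_i * exp(s_i - M));
// one update per (score s, value v) pair, matching the ms/vs bookkeeping
// in ggml_compute_forward_flash_attn_ext_f16 above
static void online_softmax_step(float s, float v, float * S, float * M, float * V) {
    if (s > *M) {
        const float ms = expf(*M - s); // rescale the old accumulators
        *V = *V*ms + v;                // the new term enters with weight 1
        *S = *S*ms + 1.0f;
        *M = s;
    } else {
        const float vs = expf(s - *M);
        *V += v*vs;
        *S += vs;
    }
}
// after the last step, the attention output is V / S
```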
 // ggml_compute_forward_flash_ff

 static void ggml_compute_forward_flash_ff_f16(
@@ -16377,6 +16710,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, masked, tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor);
@@ -17389,6 +17726,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18161,6 +18499,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 n_tasks = n_threads;
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
            {
                 n_tasks = n_threads;
             } break;
@@ -18564,6 +18903,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                 }
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 if (node->src[1]->type == GGML_TYPE_F32) {
@@ -20629,7 +20974,7 @@ static void gguf_free_kv(struct gguf_kv * kv) {
 }

 struct gguf_context * gguf_init_empty(void) {
-    struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
+    struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));

     memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
     ctx->header.version = GGUF_VERSION;
@@ -20674,7 +21019,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

     bool ok = true;

-    struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
+    struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));

     // read the header
     {
@@ -20727,9 +21072,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

     // read the kv pairs
     {
-        ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
+        const uint64_t n_kv = ctx->header.n_kv;

-        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
+        // header.n_kv will hold the actual value of pairs that were successfully read in the loop below
+        ctx->header.n_kv = 0;
+        ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
+
+        for (uint64_t i = 0; i < n_kv; ++i) {
             struct gguf_kv * kv = &ctx->kv[i];

             //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
@@ -20786,7 +21135,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                             return NULL;
                         }

-                        kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type));
+                        kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));

                         ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
                     } break;
@@ -20800,7 +21149,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                             return NULL;
                         }

-                        kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str));
+                        kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));

                         for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
                             ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
@@ -20816,6 +21165,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
             if (!ok) {
                 break;
             }
+
+            ctx->header.n_kv++;
         }

         if (!ok) {
@@ -20828,7 +21179,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p

     // read the tensor infos
     {
-        ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
+        ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));

         for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct gguf_tensor_info * info = &ctx->infos[i];
@@ -20855,8 +21206,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
             ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),   &offset);
             ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);

+            // TODO: return an error instead of crashing with GGML_ASSERT
             gguf_tensor_info_sanitize(info);

+            // make sure there is no duplicated tensor names
+            for (uint64_t j = 0; j < i; ++j) {
+                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
+                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
+                    ok = false;
+                }
+            }
+
             if (!ok) {
                 fprintf(stderr, "%s: failed to read tensor info\n", __func__);
                 fclose(file);
@@ -21025,7 +21385,7 @@ void gguf_free(struct gguf_context * ctx) {
         GGML_FREE(ctx->infos);
     }

-    GGML_ALIGNED_FREE(ctx);
+    GGML_FREE(ctx);
 }

 const char * gguf_type_name(enum gguf_type type) {
@@ -21336,7 +21696,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty
     ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
     ctx->kv[idx].value.arr.type = type;
     ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type));
+    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
     memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
 }
@@ -21346,7 +21706,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
     ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
     ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
     ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str));
+    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
     for (int i = 0; i < n; i++) {
         struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
         str->n    = strlen(data[i]);
@@ -21373,7 +21733,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
             case GGUF_TYPE_ARRAY:
                 {
                     if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                        const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *));
+                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
                         for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
                             data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
                         }
@@ -21393,6 +21753,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
 void gguf_add_tensor(
              struct gguf_context * ctx,
         const struct ggml_tensor * tensor) {
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ASSERT(false && "duplicated tensor name");
+    }
+
     const int idx = ctx->header.n_tensors;
     ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
@@ -21461,7 +21825,7 @@ struct gguf_buf {

 static struct gguf_buf gguf_buf_init(size_t size) {
     struct gguf_buf buf = {
-        /*buf.data   =*/ size == 0 ? NULL : GGML_MALLOC(size),
+        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
         /*buf.size   =*/ size,
         /*buf.offset =*/ 0,
     };
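A note on the GGML_MALLOC(n*size) to GGML_CALLOC(n, size) conversions above: besides zero-initializing the gguf structures, so a partially-read context can be freed safely after an early error, calloc in common C libraries rejects n*size products that overflow size_t, which the hand-written multiplication did not. A small sketch of the difference, using a hypothetical checked_alloc helper (not part of ggml):

```c
#include <stdint.h>
#include <stdlib.h>

// hypothetical helper illustrating the overflow check calloc performs
static void * checked_alloc(size_t n, size_t size) {
    if (size != 0 && n > SIZE_MAX / size) {
        return NULL; // malloc(n*size) would have wrapped and under-allocated
    }
    return calloc(n, size); // zero-initialized, like GGML_CALLOC
}
```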
ggml.h (22 changed lines)

@@ -482,6 +482,7 @@ extern "C" {
         GGML_OP_LEAKY_RELU,

         GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_SSM_CONV,
@@ -769,6 +770,8 @@ extern "C" {
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);

+    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
     // main

     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
@@ -1727,6 +1730,25 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);

+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
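Putting the shape comments above together, a hedged usage sketch; the dimension names (n_embd_head, n_kv, and so on) are illustrative, and the 1/sqrt(head_dim) scale is the conventional attention scaling rather than anything this header mandates:

```c
#include <math.h>
#include "ggml.h"

// illustrative dimensions; real values come from the model hyperparameters
static struct ggml_tensor * build_attn(struct ggml_context * ctx,
        int64_t n_embd_head, int64_t n_head, int64_t n_head_kv,
        int64_t n_batch, int64_t n_kv) {
    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_batch, n_head,    1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1); // not transposed
    // mask rows padded up to a multiple of GGML_KQ_MASK_PAD, as required
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

    // result is [n_embd_head, n_head, n_batch, 1] (permuted), per the comment above
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) n_embd_head));
}
```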
@@ -72,6 +72,7 @@ class Keys:

     class Tokenizer:
         MODEL            = "tokenizer.ggml.model"
+        PRE              = "tokenizer.ggml.pre"
         LIST             = "tokenizer.ggml.tokens"
         TOKEN_TYPE       = "tokenizer.ggml.token_type"
         TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"  # for BERT-style token types
@@ -940,6 +941,7 @@ KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK

 # tokenization
 KEY_TOKENIZER_MODEL      = Keys.Tokenizer.MODEL
+KEY_TOKENIZER_PRE        = Keys.Tokenizer.PRE
 KEY_TOKENIZER_LIST       = Keys.Tokenizer.LIST
 KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
 KEY_TOKENIZER_SCORES     = Keys.Tokenizer.SCORES
@@ -139,7 +139,12 @@ class GGUFReader:

     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:
-            raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
-        self.fields[field.name] = field
+            # TODO: add option to generate error on duplicate keys
+            # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
+
+            print(f'Warning: Duplicate key {field.name} at offset {field.offset}')
+            self.fields[field.name + '_{}'.format(field.offset)] = field
+        else:
+            self.fields[field.name] = field
         return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
@@ -234,8 +239,14 @@ class GGUFReader:

     def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
         tensors = []
+        tensor_names = set() # keep track of name to prevent duplicated tensors
         for field in fields:
             _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+            # check if there's any tensor having same name already in the list
+            tensor_name = str(bytes(name_data), encoding = 'utf-8')
+            if tensor_name in tensor_names:
+                raise ValueError(f'Found duplicated tensor with name {tensor_name}')
+            tensor_names.add(tensor_name)
             ggml_type = GGMLQuantizationType(raw_dtype[0])
             n_elems = np.prod(dims)
             block_size, type_size = GGML_QUANT_SIZES[ggml_type]
@@ -267,7 +278,7 @@ class GGUFReader:
                 item_count = n_bytes
                 item_type = np.uint8
             tensors.append(ReaderTensor(
-                name = str(bytes(name_data), encoding = 'utf-8'),
+                name = tensor_name,
                 tensor_type = ggml_type,
                 shape = dims,
                 n_elements = n_elems,
@@ -63,6 +63,7 @@ class GGUFWriter:
         self.kv_data_count = 0
         self.ti_data = bytearray()
         self.ti_data_count = 0
+        self.ti_names = set()
         self.use_temp_file = use_temp_file
         self.temp_file = None
         self.tensors = []
@@ -197,6 +198,10 @@ class GGUFWriter:
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')

+        if name in self.ti_names:
+            raise ValueError(f'Duplicated tensor name {name}')
+        self.ti_names.add(name)
+
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
@@ -422,6 +427,9 @@ class GGUFWriter:
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)

+    def add_tokenizer_pre(self, pre: str) -> None:
+        self.add_string(Keys.Tokenizer.PRE, pre)
+
     def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
         self.add_array(Keys.Tokenizer.LIST, tokens)
@@ -177,7 +177,7 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
     }
     else
     {
-        output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true, true);
+        output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, add_bos, true);
         if(add_bos)
         {
             llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
@@ -256,6 +256,15 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab)
     }
     return eosID;
 }
+static int GetEotID(FileFormat file_format)
+{
+    if(file_format == FileFormat::GGUF_GENERIC)
+    {
+        return llama_token_eot(&(llama_ctx_v4->model));
+    }
+    return -1;
+}

 static float LowestLogit(const std::vector<float> & logits)
 {
     int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
@@ -484,6 +493,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     }

     const llama_token eos = GetEosID(file_format,n_vocab);
+    const llama_token eot = GetEotID(file_format);

     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate>                              candidates_grammar;
@@ -491,7 +501,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id    = candidates->data[i].id;
         const std::string piece = FileFormatTokenizeID(id,file_format);
-        if (id == eos) {
+        if (id == eos || (id==eot && id!=-1)) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
             }
@@ -602,7 +612,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers

 static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct llama_grammar * grammar, llama_token token)
 {
-    if (token == GetEosID(file_format,n_vocab)) {
+    if (token == GetEosID(file_format,n_vocab) || (token!=-1 && token == GetEotID(file_format))) {
         for (const auto & stack : grammar->stacks) {
             if (stack.empty()) {
                 return;
@@ -1601,12 +1611,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
             //if it tokenizes to a single token, AND it's a single non-printable special token, use that
             std::vector<int> tmp;
             TokenizeString(stopper, tmp, file_format, false);
+            printf("\nPRINT TOK VEC:");
+            print_tok_vec_str(tmp);
             if(tmp.size()==1) //tokenizes to exactly 1 special token
             {
                 int specialid = tmp[0];
                 std::string tokenizedstr = FileFormatTokenizeID(specialid, file_format);
+                printf("\nTest %s",tokenizedstr.c_str());
                 if(tokenizedstr=="") //must NOT have a text representation
                 {
+                    printf("\nAdded %d",specialid);
                     special_stop_sequence.push_back(specialid);
                 }
             }
@@ -2167,6 +2181,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }

         unsigned int eosID = GetEosID(file_format, n_vocab);
+        unsigned int eotID = GetEotID(file_format);
         float * logitsPtr;
         float lowestLogit = 0;
         int btsize = banned_token_ids.size();
@@ -2196,6 +2211,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         {
             // set the logit of the eos token to very low to avoid sampling it
             logitsPtr[eosID] = lowestLogit;
+            if(eotID!=-1)
+            {
+                logitsPtr[eotID] = lowestLogit;
+            }
         }
         if(btsize>0)
         {
@@ -2257,7 +2276,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 printf("]\n");
             }

-            if(inputs.allow_eos_token && id==eosID)
+            if(inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1)))
             {
                 stopper_unused_tokens = remaining_tokens;
                 if(allow_regular_prints)
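The repeated `id==eosID || (id==eotID && id!=-1)` pattern above could be read as one predicate. A minimal sketch of the combined test, with eot == -1 meaning the file format has no EOT token:

```c
#include <stdbool.h>
#include <stdint.h>

// stop when the token is EOS, or EOT where the format defines one
static bool is_stop_token(int32_t id, int32_t eos, int32_t eot) {
    return id == eos || (eot != -1 && id == eot);
}
```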
@@ -8791,8 +8791,8 @@ Current version: 136
                 et = "<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
                 break;
             case "9": //llama 3 chat
-                st = "<|eot_id|><|start_header_id|>user<|end_header_id|>";
-                et = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>";
+                st = "<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n";
+                et = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n";
                 break;
             default:
                 break;
llama.h (30 changed lines)

@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 5
+#define LLAMA_SESSION_VERSION 6

 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
@@ -69,6 +69,18 @@ extern "C" {
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
     };

+    // pre-tokenization types
+    enum llama_vocab_pre_type {
+        LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+        LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+        LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+        LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+        LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+        LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+    };
+
     // note: these values should be synchronized with ggml_rope
     // TODO: maybe move this enum to ggml.h (ggml_rope_type)
     enum llama_rope_type {
@@ -195,15 +207,19 @@ extern "C" {
         LLAMA_KV_OVERRIDE_TYPE_INT,
         LLAMA_KV_OVERRIDE_TYPE_FLOAT,
         LLAMA_KV_OVERRIDE_TYPE_BOOL,
+        LLAMA_KV_OVERRIDE_TYPE_STR,
     };

     struct llama_model_kv_override {
-        char key[128];
         enum llama_model_kv_override_type tag;

+        char key[128];
+
         union {
-            int64_t int_value;
-            double  float_value;
-            bool    bool_value;
+            int64_t val_i64;
+            double  val_f64;
+            bool    val_bool;
+            char    val_str[128];
         };
     };
@@ -235,6 +251,7 @@ extern "C" {
         bool vocab_only;    // only load the vocabulary, no weights
         bool use_mmap;      // use mmap if possible
         bool use_mlock;     // force system to keep model in RAM
+        bool check_tensors; // validate model tensor data
     };

     struct llama_context_params {
@@ -270,6 +287,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -525,7 +543,7 @@ extern "C" {
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
     LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);

-    // Clear the KV cache
+    // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_cache_clear(
             struct llama_context * ctx);
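A hedged sketch of populating the reworked override struct above with the new string type; the key and value here are illustrative only:

```c
#include <stdio.h>
#include "llama.h"

// illustrative only: key and value are made up for the sketch
static struct llama_model_kv_override make_str_override(void) {
    struct llama_model_kv_override kvo;
    kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
    snprintf(kvo.key,     sizeof(kvo.key),     "tokenizer.ggml.pre");
    snprintf(kvo.val_str, sizeof(kvo.val_str), "llama3");
    return kvo;
}
```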
sgemm.cpp (170 changed lines)

@@ -50,7 +50,6 @@
 #pragma GCC diagnostic ignored "-Wignored-attributes"

 #include "sgemm.h"
-#include <algorithm>
 #include "ggml-impl.h"
 #include "ggml-quants.h"
@@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int k,
-             const TA *A, int lda,
-             const TB *B, int ldb,
-             TC *C, int ldc,
+    tinyBLAS(int64_t k,
+             const TA *A, int64_t lda,
+             const TB *B, int64_t ldb,
+             TC *C, int64_t ldc,
              int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    NOINLINE void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) {
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
 #if VECTOR_REGISTERS == 32
         case 0x55:
             mc = 5;
@@ -409,27 +408,27 @@ class tinyBLAS {
     }

     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             D Cv[RN][RM] = {};
-            for (int l = 0; l < k; l += KN)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; l += KN)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                         load<V>(B + ldb * (jj + j) + l),
                                         Cv[j][i]);
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -437,10 +436,10 @@ class tinyBLAS {
     const TA *const A;
     const TB *const B;
     TC *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -452,23 +451,23 @@ class tinyBLAS {
 template <typename TA>
 class tinyBLAS_Q0_ARM {
   public:
-    tinyBLAS_Q0_ARM(int k,
-                    const TA *A, int lda,
-                    const block_q8_0 *B, int ldb,
-                    float *C, int ldc,
+    tinyBLAS_Q0_ARM(int64_t k,
+                    const TA *A, int64_t lda,
+                    const block_q8_0 *B, int64_t ldb,
+                    float *C, int64_t ldc,
                     int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    NOINLINE void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) {
+    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
         case 0x33:
             mc = 3;
             nc = 3;
@@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM {
     }

     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             float32x4_t Cv[RN][RM] = {};
-            for (int l = 0; l < k; ++l)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                                vcvtq_f32_s32(vdotq_s32(
                                                    vdotq_s32(vdupq_n_s32(0),
@@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM {
                                                    load_hi(B + ldb * (jj + j) + l))),
                                                unhalf(A[lda * (ii + i) + l].d) *
                                                unhalf(B[ldb * (jj + j) + l].d));
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM {
     const TA *const A;
     const block_q8_0 *const B;
     float *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM {
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_Q0_AVX2 {
   public:
-    tinyBLAS_Q0_AVX2(int k,
-                     const TA *A, int lda,
-                     const TB *B, int ldb,
-                     TC *C, int ldc,
+    tinyBLAS_Q0_AVX2(int64_t k,
+                     const TA *A, int64_t lda,
+                     const TB *B, int64_t ldb,
+                     TC *C, int64_t ldc,
                      int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }

-    void matmul(int m, int n, int task) {
+    void matmul(int64_t m, int64_t n, int task) {
         if (task == GGML_TASK_TYPE_COMPUTE)
             mnpack(0, m, 0, n);
     }

   private:
-    void mnpack(int m0, int m, int n0, int n) {
-        int mc, nc, mp, np;
-        switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) {
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
 #if VECTOR_REGISTERS == 32
         case 0x44:
             mc = 4;
@@ -714,22 +713,22 @@ class tinyBLAS_Q0_AVX2 {
     }

     template <int RM, int RN>
-    NOINLINE void gemm(int m0, int m, int n0, int n) {
-        int ytiles = (m - m0) / RM;
-        int xtiles = (n - n0) / RN;
-        int tiles = xtiles * ytiles;
-        int duty = (tiles + nth - 1) / nth;
-        int start = duty * ith;
-        int end = start + duty;
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
         if (end > tiles)
             end = tiles;
-        for (int job = start; job < end; ++job) {
-            int ii = m0 + job / xtiles * RM;
-            int jj = n0 + job % xtiles * RN;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
             __m256 Cv[RN][RM] = {};
-            for (int l = 0; l < k; ++l)
-                for (int j = 0; j < RN; ++j)
-                    for (int i = 0; i < RM; ++i)
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i)
                         Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                        unhalf(B[ldb * (jj + j) + l].d)),
                                         updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
@@ -737,8 +736,8 @@ class tinyBLAS_Q0_AVX2 {
                                               _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                                load(A + lda * (ii + i) + l))),
                                         Cv[j][i]);
-            for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
         }
     }
@@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 {
     const TA *const A;
     const TB *const B;
     TC *const C;
-    const int k;
-    const int lda;
-    const int ldb;
-    const int ldc;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
     const int ith;
     const int nth;
 };
@@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C,
-                     int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

     assert(m >= 0);
     assert(n >= 0);
@@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     assert(ldc >= m);
     assert(nth > 0);
     assert(ith < nth);
-    assert(1ll * lda * m <= 0x7fffffff);
-    assert(1ll * ldb * n <= 0x7fffffff);
-    assert(1ll * ldc * n <= 0x7fffffff);

     if (Ctype != GGML_TYPE_F32)
         return false;
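The int to int64_t widening above is what lets the three `1ll * lda * m <= 0x7fffffff` asserts go away: index arithmetic like `lda * (ii + i)` no longer wraps for large matrices. A quick demonstration of the wraparound the asserts used to guard against (unsigned here, to keep the example free of signed-overflow undefined behavior):

```c
#include <inttypes.h>
#include <stdio.h>

int main(void) {
    uint32_t lda = 100000u, row = 50000u;
    uint64_t wrapped = lda * row;           // 32-bit product wraps mod 2^32: 705032704
    uint64_t exact   = (uint64_t)lda * row; // 5000000000, the intended offset
    printf("wrapped=%" PRIu64 " exact=%" PRIu64 "\n", wrapped, exact);
    return 0;
}
```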
sgemm.h
@@ -1,11 +1,13 @@
 #pragma once
+#include <stdint.h>
 #include <stdbool.h>
 #ifdef __cplusplus
 extern "C" {
 #endif

-bool llamafile_sgemm(int, int, int, const void *, int, const void *, int,
-                     void *, int, int, int, int, int, int, int);
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+                     const void *, int64_t, void *, int64_t, int, int,
+                     int, int, int, int);

 #ifdef __cplusplus
 }
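For context, a sketch of how a caller might drive the widened interface single-threaded. The GGML_TYPE_F32 / GGML_TASK_TYPE_COMPUTE enum names are assumed from ggml.h of this era, and the wrapper itself is hypothetical, not part of the diff:

#include <cstdint>
#include "ggml.h"
#include "sgemm.h"

// A and B both store k-element rows (lda/ldb >= k, per the asserts in
// sgemm.cpp); C receives n rows of m elements (ldc >= m).
static bool mul_mat_f32(int64_t m, int64_t n, int64_t k,
                        const float * A, const float * B, float * C) {
    return llamafile_sgemm(m, n, k,
                           A, /*lda=*/k,
                           B, /*ldb=*/k,
                           C, /*ldc=*/m,
                           /*ith=*/0, /*nth=*/1, // single-threaded call
                           GGML_TASK_TYPE_COMPUTE,
                           GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
}

With the int64_t dimensions, the three removed 0x7fffffff asserts above are no longer needed: products like lda * m may now exceed 32 bits by design.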
unicode-data.cpp
@@ -1,4 +1,4 @@
 #include "unicode-data.h"

 #include <cstdint>
 #include <map>
unicode.cpp
@@ -5,11 +5,14 @@
 #include <cstddef>
 #include <cstdint>
 #include <map>
+#include <regex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <locale>
+#include <codecvt>

 static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
@@ -53,23 +56,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
         offset += 4;
         return result;
     }
-    throw std::invalid_argument("invalid string");
+    throw std::invalid_argument("failed to convert utf8 to codepoint");
 }

-static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
-    std::vector<uint16_t> result;
-    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-        result.emplace_back(cp);
-    }
-    else if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
-    }
-    else {
-        throw std::invalid_argument("invalid cpt");
-    }
-    return result;
-}
+//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
+//    std::vector<uint16_t> result;
+//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
+//        result.emplace_back(cp);
+//        return result;
+//    }
+//    if (0x10000 <= cp && cp <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//        return result;
+//    }
+//    throw std::invalid_argument("failed to convert codepoint to utf16");
+//}

 //static std::vector<uint16_t> unicode_cpts_to_utf16(const std::vector<uint32_t> & cps) {
 //    std::vector<uint16_t> result;
@@ -80,28 +82,28 @@ static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
 //    return result;
 //}

-static uint32_t cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
-    assert(offset < utf16.size());
-    if (((utf16[0] >> 10) << 10) != 0xd800) {
-        auto result = utf16[offset + 0];
-        offset += 1;
-        return result;
-    }
-
-    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
-        throw std::invalid_argument("invalid character");
-    }
-
-    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
-    offset += 2;
-    return result;
-}
+//static uint32_t unicode_cpt_from_utf16(const std::vector<uint16_t> & utf16, size_t & offset) {
+//    assert(offset < utf16.size());
+//    if (((utf16[0] >> 10) << 10) != 0xd800) {
+//        auto result = utf16[offset + 0];
+//        offset += 1;
+//        return result;
+//    }
+//
+//    if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
+//        throw std::invalid_argument("invalid character");
+//    }
+//
+//    auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
+//    offset += 2;
+//    return result;
+//}

 //static std::vector<uint32_t> unicode_cpts_from_utf16(const std::vector<uint16_t> & utf16) {
 //    std::vector<uint32_t> result;
 //    size_t offset = 0;
 //    while (offset < utf16.size()) {
-//        result.push_back(cpt_from_utf16(utf16, offset));
+//        result.push_back(unicode_cpt_from_utf16(utf16, offset));
 //    }
 //    return result;
 //}
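As a reference for the UTF-16 helpers being commented out above, the surrogate-pair arithmetic they contained can be checked in isolation. A standalone sketch (the cpt_to_surrogates helper is hypothetical, not library code):

#include <cassert>
#include <cstdint>
#include <utility>

// Encode a supplementary-plane codepoint as a UTF-16 surrogate pair.
static std::pair<uint16_t, uint16_t> cpt_to_surrogates(uint32_t cp) {
    assert(0x10000 <= cp && cp <= 0x10ffff);
    cp -= 0x10000;                                  // 20 payload bits remain
    return { (uint16_t) (0xd800 | (cp >> 10)),      // high surrogate: top 10 bits
             (uint16_t) (0xdc00 | (cp & 0x03ff)) }; // low surrogate: bottom 10 bits
}

int main() {
    const auto [hi, lo] = cpt_to_surrogates(0x1f600); // U+1F600
    assert(hi == 0xd83d && lo == 0xde00);
    return 0;
}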
@@ -194,36 +196,279 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
     return map;
 }

+static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
+    return conv.from_bytes(s);
+}
+
+static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
+    std::vector<std::string> bpe_encoded_words;
+    for (const auto & word : bpe_words) {
+        std::string text_utf;
+        auto utf_word = unicode_cpts_from_utf8(word);
+        for (size_t i = 0; i < utf_word.size(); ++i) {
+            text_utf += unicode_cpt_to_utf8(utf_word[i]);
+        }
+
+        std::string encoded_token;
+        for (char & c : text_utf) {
+            encoded_token += unicode_byte_to_utf8(c);
+        }
+        bpe_encoded_words.emplace_back(encoded_token);
+    }
+    return bpe_encoded_words;
+}
+
+// GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+
+    size_t start = 0;
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    for (auto offset : offsets) {
+        std::string token;
+
+        bool collecting_numeric = false;
+        bool collecting_letter = false;
+        bool collecting_special = false;
+        bool collecting_whitespace_lookahead = false;
+        bool collecting = false;
+
+        std::vector<std::string> text_utf;
+        text_utf.reserve(offset);
+
+        for (size_t i = start; i < start + offset; ++i) {
+            text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
+        }
+
+        for (int i = 0; i < (int)text_utf.size(); i++) {
+            const std::string & utf_char = text_utf[i];
+            bool split_condition = false;
+            int bytes_remain = text_utf.size() - i;
+
+            // forward backward lookups
+            const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
+            const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
+
+            // handling contractions
+            if (!split_condition && bytes_remain >= 2) {
+                // 's|'t|'m|'d
+                if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
+                    split_condition = true;
+                }
+                if (split_condition) {
+                    if (token.size()) {
+                        bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                    }
+                    token = utf_char + utf_char_next;
+                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                    token = "";
+                    i++;
+                    continue;
+                }
+            }
+            if (!split_condition && bytes_remain >= 3) {
+                // 're|'ve|'ll
+                if (utf_char == "\'" && (
+                    (utf_char_next == "r" && utf_char_next_next == "e") ||
+                    (utf_char_next == "v" && utf_char_next_next == "e") ||
+                    (utf_char_next == "l" && utf_char_next_next == "l"))
+                    ) {
+                    split_condition = true;
+                }
+                if (split_condition) {
+                    // current token + next token can be defined
+                    if (token.size()) {
+                        bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                    }
+                    token = utf_char;
+                    token += utf_char_next;
+                    token += utf_char_next_next;
+
+                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                    token = "";
+                    i += 2;
+                    continue;
+                }
+            }
+
+            if (!split_condition && !collecting) {
+                if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
+                    collecting_letter = true;
+                    collecting = true;
+                }
+                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
+                    collecting_numeric = true;
+                    collecting = true;
+                }
+                else if (
+                    ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
+                    (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
+                    ) {
+                    collecting_special = true;
+                    collecting = true;
+                }
+                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
+                    collecting_whitespace_lookahead = true;
+                    collecting = true;
+                }
+                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
+                    split_condition = true;
+                }
+            }
+            else if (!split_condition && collecting) {
+                if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
+                    split_condition = true;
+                }
+                else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
+                    split_condition = true;
+                }
+                else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
+                    split_condition = true;
+                }
+                else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
+                    split_condition = true;
+                }
+            }
+
+            if (utf_char_next == "") {
+                split_condition = true; // final
+                token += utf_char;
+            }
+
+            if (split_condition) {
+                if (token.size()) {
+                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                }
+                token = utf_char;
+                collecting = false;
+                collecting_letter = false;
+                collecting_numeric = false;
+                collecting_special = false;
+                collecting_whitespace_lookahead = false;
+            }
+            else {
+                token += utf_char;
+            }
+        }
+
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::wregex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
+    std::wregex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
+        std::wcregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::wcmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+// use std::regex to split the text
+static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::regex expr(regex_expr);
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
+    size_t start = 0;
+    for (auto offset : offsets) {
+        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
+        std::cregex_iterator end;
+
+        int64_t start_idx = 0;
+        while (it != end) {
+            std::cmatch match = *it;
+            if (match.position() > start_idx) {
+                bpe_offsets.emplace_back(match.position() - start_idx);
+            }
+            bpe_offsets.emplace_back(match.length());
+            start_idx = match.position() + match.length();
+            ++it;
+        }
+
+        if (start_idx < (int64_t) offset) {
+            bpe_offsets.emplace_back(offset - start_idx);
+        }
+        start += offset;
+    }
+
+    return bpe_offsets;
+}
+
+static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+
+    (void)(text);
+    (void)(regex_expr);
+    (void)(offsets);
+    // TODO: this implementation is actually wrong, uncomment and run:
+    //       make -j && ./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf
+    //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
+    //    bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
+    //}
+
+    return bpe_offsets;
+}
+
 //
 // interface
 //

 std::string unicode_cpt_to_utf8(uint32_t cp) {
     std::string result;

     if (/* 0x00 <= cp && */ cp <= 0x7f) {
         result.push_back(cp);
+        return result;
     }
-    else if (0x80 <= cp && cp <= 0x7ff) {
+    if (0x80 <= cp && cp <= 0x7ff) {
         result.push_back(0xc0 | ((cp >> 6) & 0x1f));
         result.push_back(0x80 | (cp & 0x3f));
+        return result;
     }
-    else if (0x800 <= cp && cp <= 0xffff) {
+    if (0x800 <= cp && cp <= 0xffff) {
         result.push_back(0xe0 | ((cp >> 12) & 0x0f));
         result.push_back(0x80 | ((cp >> 6) & 0x3f));
         result.push_back(0x80 | (cp & 0x3f));
+        return result;
     }
-    else if (0x10000 <= cp && cp <= 0x10ffff) {
+    if (0x10000 <= cp && cp <= 0x10ffff) {
         result.push_back(0xf0 | ((cp >> 18) & 0x07));
         result.push_back(0x80 | ((cp >> 12) & 0x3f));
         result.push_back(0x80 | ((cp >> 6) & 0x3f));
         result.push_back(0x80 | (cp & 0x3f));
-    }
-    else {
-        throw std::invalid_argument("invalid codepoint");
-    }
-    return result;
-}
+        return result;
+    }
+
+    throw std::invalid_argument("invalid codepoint");
+}

 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
     std::vector<uint32_t> result;
     result.reserve(cpts.size());
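The early-return restructuring of unicode_cpt_to_utf8 above leaves the emitted byte patterns unchanged; a quick round-trip check (a sketch, assuming unicode_cpt_to_utf8 and unicode_cpts_from_utf8 are exposed via unicode.h as in this diff):

#include <cassert>
#include "unicode.h"

int main() {
    assert(unicode_cpt_to_utf8(0x00e9)  == "\xc3\xa9");          // e-acute -> 2 bytes
    assert(unicode_cpt_to_utf8(0x1f600) == "\xf0\x9f\x98\x80");  // emoji -> 4 bytes
    assert(unicode_cpts_from_utf8("\xc3\xa9")[0] == 0x00e9);     // and back
    return 0;
}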
@@ -275,3 +520,167 @@ char32_t unicode_tolower(char32_t cp) {
     auto it = unicode_map_lowercase.find(cp);
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
+
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
+    // unicode categories
+    static const std::map<std::string, int> k_ucat_enum = {
+        { "\\p{N}", CODEPOINT_TYPE_DIGIT },
+        { "\\p{L}", CODEPOINT_TYPE_LETTER },
+        { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
+    };
+
+    static const std::map<int, int> k_ucat_cpt = {
+        { CODEPOINT_TYPE_DIGIT,       0xD1 },
+        { CODEPOINT_TYPE_LETTER,      0xD2 },
+        { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
+    };
+
+    static const std::map<int, std::string> k_ucat_map = {
+        { CODEPOINT_TYPE_DIGIT,       "\x30-\x39" }, // 0-9
+        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+    };
+
+    // compute collapsed codepoints only if needed by at least one regex
+    bool need_collapse = false;
+    for (auto & regex_expr : regex_exprs) {
+        // search for unicode categories
+        for (const auto & ucat : k_ucat_enum) {
+            if (std::string::npos != regex_expr.find(ucat.first)) {
+                need_collapse = true;
+                break;
+            }
+        }
+    }
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
+    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    std::string text_collapsed;
+    if (need_collapse) {
+        // collapse all unicode categories
+        text_collapsed.resize(cpts.size());
+
+        for (size_t i = 0; i < cpts.size(); ++i) {
+            // keep single-byte codepoints as is
+            if (cpts[i] < 128) {
+                text_collapsed[i] = cpts[i];
+                continue;
+            }
+
+            const int cpt_type = unicode_cpt_type(cpts[i]);
+
+            if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(cpt_type);
+            } else {
+                text_collapsed[i] = (char) 0xD0; // fallback
+            }
+        }
+    }
+
+    std::vector<size_t> bpe_offsets = { cpts.size() };
+
+    for (auto & regex_expr : regex_exprs) {
+        // first, see if we have an efficient custom regex implementation
+        auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
+
+        if (!tmp.empty()) {
+            bpe_offsets = std::move(tmp);
+            continue;
+        }
+
+        // fallback to general-purpose std::regex / std::wregex
+        try {
+            // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
+            // with the corresponding collapsed representation
+            bool use_collapsed = false;
+            for (auto & ucat : k_ucat_enum) {
+                if (std::string::npos != regex_expr.find(ucat.first)) {
+                    use_collapsed = true;
+                    break;
+                }
+            }
+
+            if (use_collapsed) {
+                // sanity-check that the original regex does not contain any non-ASCII characters
+                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
+                for (size_t i = 0; i < cpts_regex.size(); ++i) {
+                    if (cpts_regex[i] >= 128) {
+                        throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
+                    }
+                }
+
+                // generate a collapsed representation of the regex
+                std::string regex_expr_collapsed;
+
+                // track if we are inside [], because nested [] are not allowed
+                bool inside = false;
+                for (size_t i = 0; i < regex_expr.size(); ++i) {
+                    if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) {
+                        regex_expr_collapsed += '[';
+                        inside = true;
+                        continue;
+                    }
+
+                    if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') {
+                        regex_expr_collapsed += ']';
+                        inside = false;
+                        continue;
+                    }
+
+                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                        regex_expr[i + 1] == 'p' &&
+                        regex_expr[i + 2] == '{' &&
+                        regex_expr[i + 4] == '}') {
+                        const std::string pat = regex_expr.substr(i, 5);
+                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                            if (!inside) {
+                                regex_expr_collapsed += '[';
+                            }
+                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                            if (!inside) {
+                                regex_expr_collapsed += ']';
+                            }
+                            i += 4;
+                            continue;
+                        }
+                    }
+
+                    regex_expr_collapsed += regex_expr[i];
+                }
+
+                //printf("text_collapsed: %s\n", text_collapsed.c_str());
+                //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str());
+                bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
+            } else {
+                // no unicode category used, we can use std::wregex directly
+                const std::wstring wtext = unicode_wstring_from_utf8(text);
+                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
+
+                //printf("text: %s\n", text.c_str());
+                //printf("regex_expr: %s\n", regex_expr.c_str());
+                bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
+            }
+        } catch (std::regex_error & e) {
+            fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str());
+            fprintf(stderr, "Regex error: %s\n", e.what());
+            throw std::runtime_error("Failed to process regex");
+        }
+    }
+
+    std::vector<std::string> bpe_words;
+    bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size
+
+    size_t start = 0;
+    for (size_t & offset : bpe_offsets) {
+        bpe_words.emplace_back();
+        for (size_t i = start; i < start + offset; ++i) {
+            bpe_words.back() += unicode_cpt_to_utf8(cpts[i]);
+        }
+        start += offset;
+    }
+
+    return unicode_byte_encoding_process(bpe_words);
+}
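A toy illustration of the collapse trick used by unicode_regex_split above: non-ASCII codepoints become single sentinel bytes (0xD1 for digits, 0xD2 for letters, 0xD3 for punctuation) so that plain std::regex, which lacks \p{...} support, can still emulate the category classes. Everything here besides the sentinel values is made up for illustration:

#include <cstdio>
#include <regex>
#include <string>

int main() {
    // pretend this came out of the collapse pass: 0xD2 marks a non-ASCII letter
    const std::string collapsed = "\xD2\xD2 42";
    // "\p{L}+" rewritten as a byte class: the 0xD2 sentinel plus ASCII letters
    const std::regex letters("[\xD2" "A-Za-z]+");
    std::smatch m;
    if (std::regex_search(collapsed, m, letters)) {
        // offsets found here are applied back to the original UTF-8 text
        std::printf("letters at %td, length %td\n", m.position(0), m.length(0));
    }
    return 0;
}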
unicode.h
@@ -24,5 +24,6 @@ int unicode_cpt_type(const std::string & utf8);
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);

-// simple tolower that only implements one-to-one mapping, not one-to-many
 char32_t unicode_tolower(char32_t cp);
+
+std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
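A hypothetical caller of the newly exported unicode_regex_split; the pattern string is the GPT-2 regex quoted in the unicode.cpp comment above:

#include <string>
#include <vector>
#include "unicode.h"

int main() {
    // GPT-2 pretokenizer pattern, escaped for a C++ string literal
    const std::vector<std::string> exprs = {
        "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
    };
    const std::vector<std::string> words = unicode_regex_split("Hello world 123!", exprs);
    // `words` holds byte-encoded pre-tokens; a leading space comes back in its
    // byte-to-unicode form (e.g. " world" would appear as "Gworld" with G-breve)
    return 0;
}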