diff --git a/common/common.cpp b/common/common.cpp index 20130f20b..df078ba48 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -68,7 +68,6 @@ #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 -#define LLAMA_CURL_MAX_HEADER_LENGTH 256 #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -235,8 +234,54 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return result; } +bool parse_kv_override(const char * data, std::vector & overrides) { + const char * sep = strchr(data, '='); + if (sep == nullptr || sep - data >= 128) { + fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + return false; + } + llama_model_kv_override kvo; + std::strncpy(kvo.key, data, sep - data); + kvo.key[sep - data] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.val_bool = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.val_bool = false; + } else { + fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); + return false; + } + } else if (strncmp(sep, "str:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + if (strlen(sep) > 127) { + fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + return false; + } + strncpy(kvo.val_str, sep, 127); + kvo.val_str[127] = '\0'; + } else { + fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + return false; + } + overrides.emplace_back(std::move(kvo)); + return true; +} + bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { - llama_sampling_params& sparams = params.sparams; + llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { if (++i >= argc) { @@ -848,7 +893,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - params.image = argv[i]; + params.image.emplace_back(argv[i]); return true; } if (arg == "-i" || arg == "--interactive") { @@ -903,6 +948,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1090,6 +1139,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_print = std::stoi(argv[i]); return true; } + if (arg == "--check-tensors") { + params.check_tensors = true; + return true; + } if (arg == "--ppl-output-type") { if (++i >= argc) { invalid_param = true; @@ -1241,47 +1294,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - char* sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) { - fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); - invalid_param = true; - return true; - } - struct llama_model_kv_override kvo; - std::strncpy(kvo.key, argv[i], sep - argv[i]); - kvo.key[sep - argv[i]] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); - } - else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); - } - else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; - } - else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; - } - else { - fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); - invalid_param = true; - return true; - } - } - else { + if (!parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; return true; } - params.kv_overrides.push_back(kvo); return true; } #ifndef LOG_DISABLE_LOGS @@ -1311,6 +1328,29 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } +void gpt_params_handle_model_default(gpt_params & params) { + if (!params.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (params.hf_file.empty()) { + if (params.model.empty()) { + throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); + } + params.hf_file = params.model; + } else if (params.model.empty()) { + params.model = "models/" + string_split(params.hf_file, '/').back(); + } + } else if (!params.model_url.empty()) { + if (params.model.empty()) { + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + f = string_split(f, '/').back(); + params.model = "models/" + f; + } + } else if (params.model.empty()) { + params.model = DEFAULT_MODEL_PATH; + } +} + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -1339,10 +1379,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // short-hand to avoid specifying --hf-file -> default it to --model - if (!params.hf_repo.empty() && params.hf_file.empty()) { - params.hf_file = params.model; - } + gpt_params_handle_model_default(params); if (params.escape) { process_escapes(params.prompt); @@ -1481,8 +1518,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); + printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n"); if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -1535,7 +1573,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --control-vector-layer-range START END\n"); printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: %s)\n", params.model.c_str()); + printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH); printf(" -md FNAME, --model-draft FNAME\n"); printf(" draft model for speculative decoding (default: unused)\n"); printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); @@ -1552,9 +1590,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); - printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -ptc N, --print-token-count N\n"); printf(" print token count every N tokens (default: %d)\n", params.n_print); + printf(" --check-tensors check model tensor data for invalid values\n"); printf("\n"); #ifndef LOG_DISABLE_LOGS log_print_usage(); @@ -1679,6 +1718,18 @@ std::vector string_split(std::string input, char separator) { return parts; } +std::string string_strip(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && std::isspace(str[start])) { + start++; + } + while (end > start && std::isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { {"top_k", llama_sampler_type::TOP_K}, @@ -1775,6 +1826,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; + mparams.check_tensors = params.check_tensors; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -1839,6 +1891,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -1869,59 +1922,75 @@ void llama_batch_add( #ifdef LLAMA_USE_CURL -static bool llama_download_file(CURL * curl, const char * url, const char * path) { +static bool starts_with(const std::string & str, const std::string & prefix) { + // While we wait for C++20's std::string::starts_with... + return str.rfind(prefix, 0) == 0; +} + +static bool llama_download_file(const std::string & url, const std::string & path) { + + // Initialize libcurl + std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + if (!curl) { + fprintf(stderr, "%s: error initializing libcurl\n", __func__); + return false; + } + bool force_download = false; // Set the URL, allow to follow http redirection - curl_easy_setopt(curl, CURLOPT_URL, url); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif // Check if the file already exists locally struct stat model_file_info; - auto file_exists = (stat(path, &model_file_info) == 0); + auto file_exists = (stat(path.c_str(), &model_file_info) == 0); - // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files - char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char etag_path[PATH_MAX] = {0}; - snprintf(etag_path, sizeof(etag_path), "%s.etag", path); - - char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char last_modified_path[PATH_MAX] = {0}; - snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path); + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; + std::string etag; + std::string last_modified; if (file_exists) { - auto * f_etag = fopen(etag_path, "r"); - if (f_etag) { - if (!fgets(etag, sizeof(etag), f_etag)) { - fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path); - } else { - fprintf(stderr, "%s: previous file found %s: %s\n", __func__, etag_path, etag); + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("url") && metadata["url"].is_string()) { + auto previous_url = metadata["url"].get(); + if (previous_url != url) { + fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); + return false; + } + } + if (metadata.contains("etag") && metadata["etag"].is_string()) { + etag = metadata["etag"]; + } + if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) { + last_modified = metadata["lastModified"]; + } + } catch (const nlohmann::json::exception & e) { + fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + return false; } - fclose(f_etag); - } - - auto * f_last_modified = fopen(last_modified_path, "r"); - if (f_last_modified) { - if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) { - fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path); - } else { - fprintf(stderr, "%s: previous file found %s: %s\n", __func__, last_modified_path, - last_modified); - } - fclose(f_last_modified); } + } else { + fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str()); } // Send a HEAD request to retrieve the etag and last-modified headers struct llama_load_model_from_url_headers { - char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; - char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + std::string etag; + std::string last_modified; }; llama_load_model_from_url_headers headers; { @@ -1929,38 +1998,37 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; - // Convert header field name to lowercase - for (size_t i = 0; i < n_items && buffer[i] != ':'; ++i) { - buffer[i] = tolower(buffer[i]); - } + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - const char * etag_prefix = "etag: "; - if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) { - strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF - } - - const char * last_modified_prefix = "last-modified: "; - if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) { - strncpy(headers->last_modified, buffer + strlen(last_modified_prefix), - n_items - strlen(last_modified_prefix) - 2); // Remove CRLF + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; + } } return n_items; }; - curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers); + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - CURLcode res = curl_easy_perform(curl); + CURLcode res = curl_easy_perform(curl.get()); if (res != CURLE_OK) { - curl_easy_cleanup(curl); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); return false; } long http_code = 0; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code != 200) { // HEAD not supported, we don't know if the file has changed // force trigger downloading @@ -1969,28 +2037,30 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path } } - // If the ETag or the Last-Modified headers are different: trigger a new download - bool should_download = !file_exists - || force_download - || (strlen(headers.etag) > 0 && strcmp(etag, headers.etag) != 0) - || (strlen(headers.last_modified) > 0 && strcmp(last_modified, headers.last_modified) != 0); + bool should_download = !file_exists || force_download; + if (!should_download) { + if (!etag.empty() && etag != headers.etag) { + fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + should_download = true; + } else if (!last_modified.empty() && last_modified != headers.last_modified) { + fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + should_download = true; + } + } if (should_download) { - char path_temporary[PATH_MAX] = {0}; - snprintf(path_temporary, sizeof(path_temporary), "%s.downloadInProgress", path); + std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { - fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path); - if (remove(path) != 0) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path); + fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str()); return false; } } // Set the output file - auto * outfile = fopen(path_temporary, "wb"); + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb"), fclose); if (!outfile) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path); + fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; } @@ -1998,12 +2068,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { return fwrite(data, size, nmemb, (FILE *)fd); }; - curl_easy_setopt(curl, CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile); + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); // display download progress - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); // helper function to hide password in URL auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { @@ -2022,51 +2092,34 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path // start the download fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path, headers.etag, headers.last_modified); - auto res = curl_easy_perform(curl); + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + auto res = curl_easy_perform(curl.get()); if (res != CURLE_OK) { - fclose(outfile); - curl_easy_cleanup(curl); fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); return false; } long http_code = 0; - curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code < 200 || http_code >= 400) { - fclose(outfile); - curl_easy_cleanup(curl); fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); return false; } - // Clean up - fclose(outfile); + // Causes file to be closed explicitly here before we rename it. + outfile.reset(); - // Write the new ETag to the .etag file - if (strlen(headers.etag) > 0) { - auto * etag_file = fopen(etag_path, "w"); - if (etag_file) { - fputs(headers.etag, etag_file); - fclose(etag_file); - fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag); - } - } + // Write the updated JSON metadata file. + metadata.update({ + {"url", url}, + {"etag", headers.etag}, + {"lastModified", headers.last_modified} + }); + std::ofstream(metadata_path) << metadata.dump(4); + fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - // Write the new lastModified to the .etag file - if (strlen(headers.last_modified) > 0) { - auto * last_modified_file = fopen(last_modified_path, "w"); - if (last_modified_file) { - fputs(headers.last_modified, last_modified_file); - fclose(last_modified_file); - fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path, - headers.last_modified); - } - } - - if (rename(path_temporary, path) != 0) { - curl_easy_cleanup(curl); - fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path); + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } } @@ -2084,15 +2137,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - // Initialize libcurl - auto * curl = curl_easy_init(); - - if (!curl) { - fprintf(stderr, "%s: error initializing libcurl\n", __func__); - return NULL; - } - - if (!llama_download_file(curl, model_url, path_model)) { + if (!llama_download_file(model_url, path_model)) { return NULL; } @@ -2106,7 +2151,6 @@ struct llama_model * llama_load_model_from_url( auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params); if (!ctx_gguf) { fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model); - curl_easy_cleanup(curl); return NULL; } @@ -2118,8 +2162,6 @@ struct llama_model * llama_load_model_from_url( gguf_free(ctx_gguf); } - curl_easy_cleanup(curl); - if (n_split > 1) { char split_prefix[PATH_MAX] = {0}; char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0}; @@ -2150,11 +2192,7 @@ struct llama_model * llama_load_model_from_url( char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - auto * curl = curl_easy_init(); - bool res = llama_download_file(curl, split_url, split_path); - curl_easy_cleanup(curl); - - return res; + return llama_download_file(split_url, split_path); }, idx)); } @@ -2641,7 +2679,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); + fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); @@ -2676,6 +2714,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index ce2b66807..c8ee1ac78 100644 --- a/common/common.h +++ b/common/common.h @@ -31,6 +31,8 @@ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) +#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" + // build info struct llama_control_vector_load_info; @@ -108,7 +110,7 @@ struct gpt_params { // // sampling parameters struct llama_sampling_params sparams; - std::string model = "models/7B/ggml-model-f16.gguf"; // model path + std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string model_url = ""; // model url to download @@ -164,6 +166,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens @@ -177,15 +180,20 @@ struct gpt_params { bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run + bool check_tensors = false; // validate tensor data std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file + std::string mmproj = ""; // path to multimodal projector + std::vector image; // path to image file(s) }; +void gpt_params_handle_model_default(gpt_params & params); + +bool parse_kv_override(const char * data, std::vector & overrides); + bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params); @@ -209,6 +217,7 @@ bool validate_file_name(const std::string & filename); std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector sampler_types_from_chars(const std::string & names_string); std::vector string_split(std::string input, char separator); +std::string string_strip(const std::string & str); std::string sampler_type_to_name_string(llama_sampler_type sampler_type); // diff --git a/common/log.h b/common/log.h index e4edcac7d..2b2f0e455 100644 --- a/common/log.h +++ b/common/log.h @@ -234,7 +234,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std:: // INTERNAL, DO NOT USE // USE LOG() INSTEAD // -#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER) +#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) #define LOG_IMPL(str, ...) \ do { \ if (LOG_TARGET != nullptr) \ @@ -257,7 +257,7 @@ inline std::string log_filename_generator_impl(LogTriState multilog, const std:: // INTERNAL, DO NOT USE // USE LOG_TEE() INSTEAD // -#if !defined(_MSC_VER) or defined(__INTEL_LLVM_COMPILER) +#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) #define LOG_TEE_IMPL(str, ...) \ do { \ if (LOG_TARGET != nullptr) \ diff --git a/common/sampling.cpp b/common/sampling.cpp index d8f8f4cd4..df1b26a90 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -68,7 +68,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); + seed = std::random_device{}(); } ctx->rng.seed(seed); } diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py new file mode 100644 index 000000000..b019c1e3d --- /dev/null +++ b/convert-hf-to-gguf-update.py @@ -0,0 +1,279 @@ +# This script downloads the tokenizer models of the specified models from Huggingface and +# generates the get_vocab_base_pre() function for convert-hf-to-gguf.py +# +# This is necessary in order to analyze the type of pre-tokenizer used by the model and +# provide the necessary information to llama.cpp via the GGUF header in order to implement +# the same pre-tokenizer. +# +# ref: https://github.com/ggerganov/llama.cpp/pull/6920 +# +# Instructions: +# +# - Add a new model to the "models" list +# - Run the script with your huggingface token: +# +# python3 convert-hf-to-gguf-update.py +# +# - Copy-paste the generated get_vocab_base_pre() function into convert-hf-to-gguf.py +# - Update llama.cpp with the new pre-tokenizer if necessary +# +# TODO: generate tokenizer tests for llama.cpp +# TODO: automate the update of convert-hf-to-gguf.py +# + +import os +import requests +import sys +import json + +from hashlib import sha256 +from enum import IntEnum, auto + +class TOKENIZER_TYPE(IntEnum): + SPM = auto() + BPE = auto() + WPM = auto() + +# TODO: this string has to exercise as much pre-tokenizer functionality as possible +# will be updated with time - contributions welcome +chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + +if len(sys.argv) == 2: + token = sys.argv[1] +else: + print("Usage: python convert-hf-to-gguf-update.py ") + sys.exit(1) + +# TODO: add models here, base models preferred +models = [ + { "name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, + { "name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, + { "name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, + { "name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, + { "name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, + { "name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, + { "name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, + { "name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, + { "name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, + { "name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, + ] + +# make directory "models/tokenizers" if it doesn't exist +if not os.path.exists("models/tokenizers"): + os.makedirs("models/tokenizers") + +def download_file_with_auth(url, token, save_path): + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(url, headers=headers) + if response.status_code == 200: + with open(save_path, 'wb') as f: + f.write(response.content) + print(f"File {save_path} downloaded successfully") + else: + print(f"Failed to download file. Status code: {response.status_code}") + +# download the tokenizer models +for model in models: + name = model["name"] + repo = model["repo"] + tokt = model["tokt"] + + if not os.path.exists(f"models/tokenizers/{name}"): + os.makedirs(f"models/tokenizers/{name}") + else: + print(f"Directory models/tokenizers/{name} already exists - skipping") + continue + + print(f"Downloading {name} to models/tokenizers/{name}") + + url = f"{repo}/raw/main/config.json" + save_path = f"models/tokenizers/{name}/config.json" + download_file_with_auth(url, token, save_path) + + url = f"{repo}/raw/main/tokenizer.json" + save_path = f"models/tokenizers/{name}/tokenizer.json" + download_file_with_auth(url, token, save_path) + + if tokt == TOKENIZER_TYPE.SPM: + url = f"{repo}/resolve/main/tokenizer.model" + save_path = f"models/tokenizers/{name}/tokenizer.model" + download_file_with_auth(url, token, save_path) + + url = f"{repo}/raw/main/tokenizer_config.json" + save_path = f"models/tokenizers/{name}/tokenizer_config.json" + download_file_with_auth(url, token, save_path) + +# generate the source code for the convert-hf-to-gguf.py:get_vocab_base_pre() function: +# TODO: auto-update convert-hf-to-gguf.py with the generated function + +src_ifs = "" +for model in models: + name = model["name"] + tokt = model["tokt"] + + if tokt == TOKENIZER_TYPE.SPM: + continue + + # create the tokenizer + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + print(f"model: {name}") + print(f"tokt: {tokt}") + print(f"repo: {model['repo']}") + print(f"chktok: {chktok}") + print(f"chkhsh: {chkhsh}") + + # print the "pre_tokenizer" content from the tokenizer.json + with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + cfg = json.load(f) + pre_tokenizer = cfg["pre_tokenizer"] + print("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + + print(f"\n") + + src_ifs += f" if chkhsh == \"{chkhsh}\":\n" + src_ifs += f" # ref: {model['repo']}\n" + src_ifs += f" res = \"{name}\"\n" + +src_func = "" +src_func += " def get_vocab_base_pre(self, tokenizer) -> str:\n" +src_func += " # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that\n" +src_func += " # is specific for the BPE pre-tokenizer used by the model\n" +src_func += " # we will use this unique identifier to write a \"tokenizer.ggml.pre\" entry in the GGUF file which we can\n" +src_func += " # use in llama.cpp to implement the same pre-tokenizer\n" +src_func += "\n" +src_func += f" chktxt = {repr(chktxt)}\n" +src_func += "\n" +src_func += " chktok = tokenizer.encode(chktxt)\n" +src_func += " chkhsh = sha256(str(chktok).encode()).hexdigest()\n" +src_func += "\n" +src_func += " print(f\"chktok: {chktok}\")\n" +src_func += " print(f\"chkhsh: {chkhsh}\")\n" +src_func += "\n" +src_func += " res = None\n" +src_func += "\n" +src_func += " # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script\n" +src_func += " # or pull the latest version of the model from Huggingface\n" +src_func += " # don't edit the hashes manually!\n" +src_func += f"{src_ifs}\n" +src_func += " if res is None:\n" +src_func += " print(\"\\n\")\n" +src_func += " print(\"**************************************************************************************\")\n" +src_func += " print(\"** WARNING: The BPE pre-tokenizer was not recognized!\")\n" +src_func += " print(\"** There are 2 possible reasons for this:\")\n" +src_func += " print(\"** - the model has not been added to convert-hf-to-gguf-update.py yet\")\n" +src_func += " print(\"** - the pre-tokenization config has changed upstream\")\n" +src_func += " print(\"** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.\")\n" +src_func += " print(\"** ref: https://github.com/ggerganov/llama.cpp/pull/6920\")\n" +src_func += " print(\"**\")\n" +src_func += " print(f\"** chkhsh: {chkhsh}\")\n" +src_func += " print(\"**************************************************************************************\")\n" +src_func += " print(\"\\n\")\n" +src_func += " raise NotImplementedError(\"BPE pre-tokenizer was not recognized - update get_vocab_base_pre()\")\n" +src_func += "\n" +src_func += " print(f\"tokenizer.ggml.pre: {res}\")\n" +src_func += " print(f\"chkhsh: {chkhsh}\")\n" +src_func += "\n" +src_func += " return res\n" + +print(src_func) + +print("\n") +print("!!! Copy-paste the function above into convert-hf-to-gguf.py !!!") +print("\n") + +# generate tests for each tokenizer model + +tests = [ + "", + " ", + " ", + " ", + "\t", + "\n", + "\n\n", + "\n\n\n", + "\t\n", + "Hello world", + " Hello world", + "Hello World", + " Hello World", + " Hello World!", + "Hello, world!", + " Hello, world!", + " this is 🦙.cpp", + "w048 7tuijk dsdfhu", + "нещо на Български", + "កាន់តែពិសេសអាចខលចេញ", + "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + "Hello", + " Hello", + " Hello", + " Hello", + " Hello", + " Hello\n Hello", + " (", + "\n =", + "' era", + "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", + "3", + "33", + "333", + "3333", + "33333", + "333333", + "3333333", + "33333333", + "333333333", + chktxt, +] + +# write the tests to ./models/ggml-vocab-{name}.gguf.inp +# the format is: +# +# test0 +# __ggml_vocab_test__ +# test1 +# __ggml_vocab_test__ +# ... +# + +# with each model, encode all tests and write the results in ./models/ggml-vocab-{name}.gguf.out +# for each test, write the resulting tokens on a separate line + +for model in models: + name = model["name"] + tokt = model["tokt"] + + # create the tokenizer + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + + with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f: + for text in tests: + f.write(f"{text}") + f.write("\n__ggml_vocab_test__\n") + + with open(f"models/ggml-vocab-{name}.gguf.out", "w") as f: + for text in tests: + res = tokenizer.encode(text, add_special_tokens=False) + for r in res: + f.write(f" {r}") + f.write("\n") + + print(f"Tests for {name} written in ./models/ggml-vocab-{name}.gguf.*") + +# generate commands for creating vocab files + +print("\nRun the following commands to generate the vocab files for testing:\n") + +for model in models: + name = model["name"] + + print(f"python3 convert-hf-to-gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") + +print("\n") diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5763b6664..2f146d730 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -11,6 +11,7 @@ import sys from abc import ABC, abstractmethod from enum import IntEnum from pathlib import Path +from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast import numpy as np @@ -229,7 +230,7 @@ class Model(ABC): return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) # used for GPT-2 BPE and WordPiece vocabs - def get_basic_vocab(self) -> tuple[list[str], list[int]]: + def get_vocab_base(self) -> tuple[list[str], list[int], str]: tokens: list[str] = [] toktypes: list[int] = [] @@ -238,6 +239,8 @@ class Model(ABC): vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size + tokpre = self.get_vocab_base_pre(tokenizer) + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} added_vocab = tokenizer.get_added_vocab() @@ -255,11 +258,79 @@ class Model(ABC): tokens.append(reverse_vocab[i]) toktypes.append(gguf.TokenType.NORMAL) - return tokens, toktypes + return tokens, toktypes, tokpre + + # NOTE: this function is generated by convert-hf-to-gguf-update.py + # do not modify it manually! + # ref: https://github.com/ggerganov/llama.cpp/pull/6920 + def get_vocab_base_pre(self, tokenizer) -> str: + # encoding this string and hashing the resulting tokens would (hopefully) give us a unique identifier that + # is specific for the BPE pre-tokenizer used by the model + # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can + # use in llama.cpp to implement the same pre-tokenizer + + chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + + chktok = tokenizer.encode(chktxt) + chkhsh = sha256(str(chktok).encode()).hexdigest() + + print(f"chktok: {chktok}") + print(f"chkhsh: {chkhsh}") + + res = None + + # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script + # or pull the latest version of the model from Huggingface + # don't edit the hashes manually! + if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": + # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B + res = "llama-bpe" + if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754": + # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base + res = "deepseek-llm" + if chkhsh == "347715f544604f9118bb75ed199f68779f423cabb20db6de6f31b908d04d7821": + # ref: https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base + res = "deepseek-coder" + if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": + # ref: https://huggingface.co/tiiuae/falcon-7b + res = "falcon" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 + res = "bert-bge" + if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": + # ref: https://huggingface.co/mosaicml/mpt-7b + res = "mpt" + if chkhsh == "35d91631860c815f952d711435f48d356ebac988362536bed955d43bfa436e34": + # ref: https://huggingface.co/bigcode/starcoder2-3b + res = "starcoder" + if chkhsh == "3ce83efda5659b07b1ad37ca97ca5797ea4285d9b9ab0dc679e4a720c9da7454": + # ref: https://huggingface.co/openai-community/gpt2 + res = "gpt-2" + + if res is None: + print("\n") + print("**************************************************************************************") + print("** WARNING: The BPE pre-tokenizer was not recognized!") + print("** There are 2 possible reasons for this:") + print("** - the model has not been added to convert-hf-to-gguf-update.py yet") + print("** - the pre-tokenization config has changed upstream") + print("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.") + print("** ref: https://github.com/ggerganov/llama.cpp/pull/6920") + print("**") + print(f"** chkhsh: {chkhsh}") + print("**************************************************************************************") + print("\n") + raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + + print(f"tokenizer.ggml.pre: {res}") + print(f"chkhsh: {chkhsh}") + + return res def _set_vocab_gpt2(self) -> None: - tokens, toktypes = self.get_basic_vocab() + tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) @@ -277,6 +348,8 @@ class Model(ABC): vocab_size = hparams["vocab_size"] assert max(tokenizer.get_vocab().values()) < vocab_size + tokpre = self.get_vocab_base_pre(tokenizer) + merges = [] vocab = {} mergeable_ranks = tokenizer.mergeable_ranks @@ -304,6 +377,7 @@ class Model(ABC): toktypes.append(gguf.TokenType.NORMAL) self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) @@ -376,6 +450,7 @@ class Model(ABC): assert len(tokens) == vocab_size self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) @@ -397,6 +472,7 @@ class Model(ABC): assert len(tokens) == vocab.vocab_size self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) @@ -840,6 +916,7 @@ class XverseModel(Model): toktypes.append(toktype) self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) @@ -1335,6 +1412,11 @@ class LlamaModel(Model): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + # Same as super class, but permuting q_proj, k_proj def write_tensors(self): block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) @@ -2052,6 +2134,7 @@ class Phi3MiniModel(Model): toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) @@ -2294,6 +2377,7 @@ class InternLM2Model(Model): toktypes.append(SentencePieceTokenTypes.USER_DEFINED) self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) @@ -2443,7 +2527,7 @@ class BertModel(Model): self.gguf_writer.add_pooling_type(pooling_type) def set_vocab(self): - tokens, toktypes = self.get_basic_vocab() + tokens, toktypes, tokpre = self.get_vocab_base() self.vocab_size = len(tokens) # we need this to validate the size of the token_type embeddings @@ -2461,6 +2545,7 @@ class BertModel(Model): # add vocab to gguf self.gguf_writer.add_tokenizer_model("bert") + self.gguf_writer.add_tokenizer_pre(tokpre) self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) @@ -2482,6 +2567,10 @@ class BertModel(Model): print(f"Can not map tensor {name!r}") sys.exit() + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + data = data_torch.squeeze().numpy() n_dims = len(data.shape) new_dtype: type[np.floating[Any]] @@ -2638,6 +2727,9 @@ class MambaModel(Model): field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL) self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1])) + field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE) + self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1])) + field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST) self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) @@ -2843,6 +2935,7 @@ def parse_args() -> argparse.Namespace: help="directory containing model file", ) parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)") + parser.add_argument("--model-name", type=str, default=None, help="name of the model") return parser.parse_args() diff --git a/convert-llama-ggml-to-gguf.py b/convert-llama-ggml-to-gguf.py index cd9644fcb..5354b748b 100755 --- a/convert-llama-ggml-to-gguf.py +++ b/convert-llama-ggml-to-gguf.py @@ -281,6 +281,7 @@ class GGMLToGGUF: def add_vocab(self, gguf_writer): hp = self.model.hyperparameters gguf_writer.add_tokenizer_model('llama') + gguf_writer.add_tokenizer_pre('default') tokens = [] scores = [] toktypes = [] diff --git a/convert-persimmon-to-gguf.py b/convert-persimmon-to-gguf.py index 69be17f94..aba575426 100755 --- a/convert-persimmon-to-gguf.py +++ b/convert-persimmon-to-gguf.py @@ -99,6 +99,7 @@ def main(): tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir) gguf_writer.add_tokenizer_model('llama') + gguf_writer.add_tokenizer_pre('default') gguf_writer.add_token_list(tokens) gguf_writer.add_token_scores(scores) gguf_writer.add_token_types(toktypes) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1e34de620..2924d8116 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,7 +32,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; @@ -41,6 +41,7 @@ int main(int argc, char ** argv) { int n_kv_max = 2048; int n_batch = 2048; int n_ubatch = 512; + bool flash_attn = false; int is_pp_shared = 0; int n_gpu_layers = 0; @@ -66,23 +67,27 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - is_pp_shared = std::atoi(argv[5]); + flash_attn = std::atoi(argv[5]); } if (argc >= 7) { - n_gpu_layers = std::atoi(argv[6]); + is_pp_shared = std::atoi(argv[6]); } if (argc >= 8) { - n_pp = parse_list(argv[7]); + n_gpu_layers = std::atoi(argv[7]); } if (argc >= 9) { - n_tg = parse_list(argv[8]); + n_pp = parse_list(argv[8]); } if (argc >= 10) { - n_pl = parse_list(argv[9]); + n_tg = parse_list(argv[9]); + } + + if (argc >= 11) { + n_pl = parse_list(argv[10]); } // init LLM @@ -108,10 +113,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = n_batch; - ctx_params.n_ubatch = n_ubatch; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = n_batch; + ctx_params.n_ubatch = n_ubatch; + ctx_params.flash_attn = flash_attn; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -169,7 +175,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 6c03b5426..84037c96d 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -24,6 +24,7 @@ struct Stats { }; struct StatParams { + std::string dataset; std::string ofile = "imatrix.dat"; int n_output_frequency = 10; int verbosity = 1; @@ -47,7 +48,7 @@ private: std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id // - void save_imatrix(const char * file_name) const; + void save_imatrix(const char * file_name, const char * dataset) const; void keep_imatrix(int ncall) const; }; @@ -200,7 +201,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } void IMatrixCollector::save_imatrix() const { - save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str()); + save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str(), m_params.dataset.c_str()); } void IMatrixCollector::keep_imatrix(int ncall) const { @@ -208,24 +209,33 @@ void IMatrixCollector::keep_imatrix(int ncall) const { if (file_name.empty()) file_name = "imatrix.dat"; file_name += ".at_"; file_name += std::to_string(ncall); - save_imatrix(file_name.c_str()); + save_imatrix(file_name.c_str(), m_params.dataset.c_str()); } -void IMatrixCollector::save_imatrix(const char * fname) const { +void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) const { std::ofstream out(fname, std::ios::binary); int n_entries = m_stats.size(); - out.write((const char*)&n_entries, sizeof(n_entries)); - for (auto& p : m_stats) { + out.write((const char *) &n_entries, sizeof(n_entries)); + for (const auto & p : m_stats) { int len = p.first.size(); - out.write((const char*)&len, sizeof(len)); + out.write((const char *) &len, sizeof(len)); out.write(p.first.c_str(), len); - out.write((const char*)&p.second.ncall, sizeof(p.second.ncall)); + out.write((const char *) &p.second.ncall, sizeof(p.second.ncall)); int nval = p.second.values.size(); - out.write((const char*)&nval, sizeof(nval)); - if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float)); + out.write((const char *) &nval, sizeof(nval)); + if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float)); } + + // Write the number of call the matrix was computed with + out.write((const char *) &m_last_call, sizeof(m_last_call)); + + // Write the dataset name at the end of the file to later on specify it in quantize + int n_dataset = strlen(dataset); + out.write((const char *) &n_dataset, sizeof(n_dataset)); + out.write(dataset, n_dataset); + if (m_params.verbosity > 0) { - fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname); + fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname); } } @@ -548,6 +558,29 @@ int main(int argc, char ** argv) { } } + gpt_params params; + params.n_batch = 512; + if (!gpt_params_parse(args.size(), args.data(), params)) { + return 1; + } + + params.logits_all = true; + params.n_batch = std::min(params.n_batch, params.n_ctx); + + print_build_info(); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + sparams.dataset = params.prompt_file; g_collector.set_parameters(std::move(sparams)); if (!combine_files.empty()) { @@ -586,28 +619,6 @@ int main(int argc, char ** argv) { } } - gpt_params params; - params.n_batch = 512; - if (!gpt_params_parse(args.size(), args.data(), params)) { - return 1; - } - - params.logits_all = true; - params.n_batch = std::min(params.n_batch, params.n_ctx); - - print_build_info(); - - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); - if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); - } - llama_backend_init(); llama_numa_init(params.numa); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 7ef1cf2a7..d3a71382a 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -175,6 +175,7 @@ struct cmd_params { std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; + std::vector flash_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -196,6 +197,7 @@ static const cmd_params cmd_params_defaults = { /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, + /* flash_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -221,6 +223,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -394,6 +397,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -478,6 +488,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -499,6 +510,7 @@ struct cmd_params_instance { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -533,6 +545,7 @@ struct cmd_params_instance { cparams.type_k = type_k; cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; return cparams; @@ -555,6 +568,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) + for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -573,6 +587,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -597,6 +612,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -634,6 +650,7 @@ struct test { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -658,6 +675,7 @@ struct test { split_mode = inst.split_mode; main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -732,7 +750,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", + "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -754,7 +772,7 @@ struct test { } if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -788,7 +806,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -956,6 +974,9 @@ struct markdown_printer : public printer { if (field == "no_kv_offload") { return "nkvo"; } + if (field == "flash_attn") { + return "fa"; + } if (field == "use_mmap") { return "mmap"; } @@ -1002,6 +1023,9 @@ struct markdown_printer : public printer { if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) { fields.emplace_back("no_kv_offload"); } + if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { + fields.emplace_back("flash_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index a44c6cd76..157a680b5 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -113,11 +113,11 @@ struct llava_context { }; static void show_additional_info(int /*argc*/, char ** argv) { - LOG_TEE("\n example usage: %s -m --mmproj --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG_TEE("\n example usage: %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) { +static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) { // load and preprocess the image llava_image_embed * embed = NULL; @@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para } params->prompt = remove_image_from_prompt(prompt); } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str()); + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str()); if (!embed) { - LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str()); + fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); return NULL; } } @@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ printf("\n"); } - -static struct llava_context * llava_init(gpt_params * params) { - const char * clip_path = params->mmproj.c_str(); - - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - +static struct llama_model * llava_init(gpt_params * params) { llama_backend_init(); llama_numa_init(params->numa); @@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) { LOG_TEE("%s: error: unable to load model\n" , __func__); return NULL; } + return model; +} + +static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) { + const char * clip_path = params->mmproj.c_str(); + + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings @@ -286,24 +289,30 @@ int main(int argc, char ** argv) { show_additional_info(argc, argv); return 1; } - - auto ctx_llava = llava_init(¶ms); - if (ctx_llava == NULL) { - LOG_TEE("%s: error: failed to init llava\n", __func__); + auto model = llava_init(¶ms); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init llava model\n", __func__); return 1; } - auto image_embed = load_image(ctx_llava, ¶ms); - if (!image_embed) { - return 1; + for (auto & image : params.image) { + auto ctx_llava = llava_init_context(¶ms, model); + + auto image_embed = load_image(ctx_llava, ¶ms, image); + if (!image_embed) { + std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; + return 1; + } + + // process the prompt + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); + + llama_print_timings(ctx_llava->ctx_llama); + llava_image_embed_free(image_embed); + ctx_llava->model = NULL; + llava_free(ctx_llava); } + llama_free_model(model); - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_print_timings(ctx_llava->ctx_llama); - - llava_image_embed_free(image_embed); - llava_free(ctx_llava); return 0; } diff --git a/examples/main-cmake-pkg/README.md b/examples/main-cmake-pkg/README.md index f599fbaec..edf20d8db 100644 --- a/examples/main-cmake-pkg/README.md +++ b/examples/main-cmake-pkg/README.md @@ -17,11 +17,9 @@ In this case, CLBlast was already installed so the CMake package is referenced i ```cmd git clone https://github.com/ggerganov/llama.cpp cd llama.cpp -mkdir build -cd build -cmake .. -DBUILD_SHARED_LIBS=OFF -DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH=C:/CLBlast/lib/cmake/CLBlast -G "Visual Studio 17 2022" -A x64 -cmake --build . --config Release -cmake --install . --prefix C:/LlamaCPP +cmake -B build -DBUILD_SHARED_LIBS=OFF -DLLAMA_CLBLAST=ON -DCMAKE_PREFIX_PATH=C:/CLBlast/lib/cmake/CLBlast -G "Visual Studio 17 2022" -A x64 +cmake --build build --config Release +cmake --install build --prefix C:/LlamaCPP ``` ### Build main-cmake-pkg @@ -29,9 +27,7 @@ cmake --install . --prefix C:/LlamaCPP ```cmd cd ..\examples\main-cmake-pkg -mkdir build -cd build -cmake .. -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/CLBlast/lib/cmake/CLBlast;C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64 -cmake --build . --config Release -cmake --install . --prefix C:/MyLlamaApp +cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/CLBlast/lib/cmake/CLBlast;C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64 +cmake --build build --config Release +cmake --install build --prefix C:/MyLlamaApp ``` diff --git a/examples/main/README.md b/examples/main/README.md index 649f4e0f3..e7a38743c 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -66,7 +66,7 @@ main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt In this section, we cover the most commonly used options for running the `main` program with the LLaMA models: -- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). +- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`; inferred from `--model-url` if set). - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf). - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. - `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4cb4b4119..7f42d26b6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -325,7 +325,7 @@ int main(int argc, char ** argv) { log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size()); // if we will use the cache for the full prompt without reaching the end of the cache, force - // reevaluation of the last token token to recalculate the cached logits + // reevaluation of the last token to recalculate the cached logits if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 57973e8b4..7e57d3652 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -24,7 +24,7 @@ #endif struct quantize_stats_params { - std::string model = "models/7B/ggml-model-f16.gguf"; + std::string model = DEFAULT_MODEL_PATH; bool verbose = false; bool per_layer_stats = false; bool print_histogram = false; diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt index 6f374a2bd..6b977fde8 100644 --- a/examples/quantize/CMakeLists.txt +++ b/examples/quantize/CMakeLists.txt @@ -1,6 +1,6 @@ set(TARGET quantize) add_executable(${TARGET} quantize.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ../../common) target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index bbb315a7e..80056165b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -9,7 +9,6 @@ #include #include #include -#include struct quant_option { std::string name; @@ -54,6 +53,10 @@ static const std::vector QUANT_OPTIONS = { { "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", }, }; +static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file"; +static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset"; +static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count"; +static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count"; static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) { std::string ftype_str; @@ -114,7 +117,7 @@ static void usage(const char * executable) { exit(1); } -static void load_imatrix(const std::string & imatrix_file, std::unordered_map> & imatrix_data) { +static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map> & imatrix_data) { std::ifstream in(imatrix_file.c_str(), std::ios::binary); if (!in) { printf("%s: failed to open %s\n",__func__, imatrix_file.c_str()); @@ -161,18 +164,33 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map dataset_as_vec(dataset_len); + in.read(dataset_as_vec.data(), dataset_len); + imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end()); + printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); + } + printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call); + return m_last_call; } -static void prepare_imatrix(const std::string & imatrix_file, +static int prepare_imatrix(const std::string & imatrix_file, + std::string & imatrix_dataset, const std::vector & included_weights, const std::vector & excluded_weights, std::unordered_map> & imatrix_data) { + int m_last_call = -1; if (!imatrix_file.empty()) { - load_imatrix(imatrix_file, imatrix_data); + m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data); } if (imatrix_data.empty()) { - return; + return m_last_call; } if (!excluded_weights.empty()) { for (auto& name : excluded_weights) { @@ -198,6 +216,7 @@ static void prepare_imatrix(const std::string & imatrix_file, if (!imatrix_data.empty()) { printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size())); } + return m_last_call; } static ggml_type parse_ggml_type(const char * arg) { @@ -212,43 +231,6 @@ static ggml_type parse_ggml_type(const char * arg) { return result; } -static bool parse_kv_override(const char * data, std::vector & overrides) { - const char* sep = strchr(data, '='); - if (sep == nullptr || sep - data >= 128) { - fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); - return false; - } - llama_model_kv_override kvo; - std::strncpy(kvo.key, data, sep - data); - kvo.key[sep - data] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; - } else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; - } else { - fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); - return false; - } - } else { - fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); - return false; - } - overrides.emplace_back(std::move(kvo)); - return true; -} - int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -317,10 +299,43 @@ int main(int argc, char ** argv) { usage(argv[0]); } + std::string imatrix_dataset; std::unordered_map> imatrix_data; - prepare_imatrix(imatrix_file, included_weights, excluded_weights, imatrix_data); + int m_last_call = prepare_imatrix(imatrix_file, imatrix_dataset, included_weights, excluded_weights, imatrix_data); if (!imatrix_data.empty()) { params.imatrix = &imatrix_data; + { + llama_model_kv_override kvo; + std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + strncpy(kvo.val_str, imatrix_file.c_str(), 127); + kvo.val_str[127] = '\0'; + kv_overrides.emplace_back(std::move(kvo)); + } + if (!imatrix_dataset.empty()) { + llama_model_kv_override kvo; + std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + kvo.val_str[127] = '\0'; + kv_overrides.emplace_back(std::move(kvo)); + } + + { + llama_model_kv_override kvo; + std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = imatrix_data.size(); + kv_overrides.emplace_back(std::move(kvo)); + } + + if (m_last_call > 0) { + llama_model_kv_override kvo; + std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS); + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = m_last_call; + kv_overrides.emplace_back(std::move(kvo)); + } } if (!kv_overrides.empty()) { kv_overrides.emplace_back(); diff --git a/examples/server/README.md b/examples/server/README.md index 918ac1295..b96a4444a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -74,15 +74,18 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - Using `make`: ```bash - make + make server ``` - Using `CMake`: ```bash - cmake --build . --config Release + cmake -B build + cmake --build build --config Release -t server ``` + Binary is at `./build/bin/server` + ## Build with SSL `server` can also be built with SSL support using OpenSSL 3 @@ -99,10 +102,8 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - Using `CMake`: ```bash - mkdir build - cd build - cmake .. -DLLAMA_SERVER_SSL=ON - make server + cmake -B build -DLLAMA_SERVER_SSL=ON + cmake --build build --config Release -t server ``` ## Quick Start diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 6ca637bdd..86c5de101 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -268,6 +268,7 @@ def start_server_background(args): server_args.extend(['--defrag-thold', "0.1"]) server_args.append('--cont-batching') server_args.append('--metrics') + server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") diff --git a/examples/server/bench/script.js b/examples/server/bench/script.js index c4c486cdf..bdf4f5abc 100644 --- a/examples/server/bench/script.js +++ b/examples/server/bench/script.js @@ -90,7 +90,8 @@ export default function () { "model": model, "stream": true, "seed": 42, - "max_tokens": max_tokens + "max_tokens": max_tokens, + "stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS } const params = {method: 'POST', body: JSON.stringify(payload)}; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 63eb82c0e..bbdb462da 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1208,6 +1208,27 @@ struct server_context { LOG_VERBOSE("eos token found", {}); } + auto n_ctx_train = llama_n_ctx_train(model); + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 + && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + LOG_WARNING("n_predict is not set and self-context extend is disabled." + " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { + { "id_slot", slot.id }, + { "params.n_predict", slot.params.n_predict }, + { "slot.n_prompt_tokens", slot.n_prompt_tokens }, + { "slot.n_decoded", slot.n_decoded }, + { "slot.n_predict", slot.n_predict }, + { "n_slots", params.n_parallel }, + { "slot.n_ctx", slot.n_ctx }, + { "n_ctx", n_ctx }, + { "n_ctx_train", n_ctx_train }, + { "ga_n", slot.ga_n }, + }); + slot.truncated = true; + slot.stopped_limit = true; + slot.has_next_token = false; // stop prediction + } + LOG_VERBOSE("next token", { {"id_slot", slot.id}, {"id_task", slot.id_task}, @@ -2142,7 +2163,7 @@ struct server_context { }); // process the created batch of tokens - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); for (auto & slot : slots) { @@ -2333,7 +2354,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" disable KV offload\n"); } printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: %s)\n", params.model.c_str()); + printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH); printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); printf(" model download url (default: unused)\n"); printf(" -hfr REPO, --hf-repo REPO\n"); @@ -2357,6 +2378,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" -ctk TYPE, --cache-type-k TYPE\n"); @@ -2372,7 +2394,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); - printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n"); printf(" --chat-template JINJA_TEMPLATE\n"); @@ -2722,6 +2744,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; + } else if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; } else if (arg == "-np" || arg == "--parallel") { if (++i >= argc) { invalid_param = true; @@ -2803,43 +2827,11 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } - char * sep = strchr(argv[i], '='); - if (sep == nullptr || sep - argv[i] >= 128) { - fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]); - invalid_param = true; - break; - } - - struct llama_model_kv_override kvo; - std::strncpy(kvo.key, argv[i], sep - argv[i]); - kvo.key[sep - argv[i]] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; - } else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; - } else { - fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); - invalid_param = true; - break; - } - } else { + if (!parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; break; } - params.kv_overrides.push_back(kvo); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); @@ -2847,6 +2839,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, } } + gpt_params_handle_model_default(params); + if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature index dcf1434f9..6f163ce04 100644 --- a/examples/server/tests/features/embeddings.feature +++ b/examples/server/tests/features/embeddings.feature @@ -5,7 +5,7 @@ Feature: llama.cpp server Background: Server startup Given a server listening on localhost:8080 And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf - And a model file ggml-model-f16.gguf + And a model file bert-bge-small.gguf And a model alias bert-bge-small And 42 as server seed And 2 slots diff --git a/ggml-backend.c b/ggml-backend.c index e91d97cd9..f5bdcf078 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1784,12 +1784,14 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { void ggml_backend_sched_reset(ggml_backend_sched_t sched) { // reset state for the next run - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT - memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); - memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); + if (!sched->is_reset) { + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT + memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); + memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); - sched->is_reset = true; + sched->is_reset = true; + } sched->is_alloc = false; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 298bd95a5..ba65d4747 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -16,6 +16,7 @@ static bool g_mul_mat_q = false; #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" +#include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -142,6 +143,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { @@ -2296,6 +2298,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2570,6 +2575,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 1c8311966..276c3fa23 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -142,6 +142,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) @@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); + const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL + +#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -404,6 +415,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index b15e35782..75e50c985 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -5,16 +5,16 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { - const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x); + const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x); if (i >= k) { return; } const int64_t ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int64_t iqs = (i%qk)/qr; // quant index + const int64_t iybs = i - i%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; @@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h #if __CUDA_ARCH__ >= CC_PASCAL constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; - const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; + const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; const int * x0 = ((int *) vx) + blockIdx.x * nint; half2 * y2 = (half2 *) (y + i0); @@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; const int64_t ib = 8*i + ir; if (ib >= nb32) { return; @@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; const int64_t ib = 8*i + ir; if (ib >= nb32) { return; @@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; const uint8_t q = x[i].qs[32*n + l]; dst_t * y = yy + i*QK_K + 128*n; @@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); #else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); dst_t * y = yy + i*QK_K + 16*is + il; float dall = __low2half(x[i].dm); @@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; #if QK_K == 256 - const int r = threadIdx.x/4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16*is0 + 4*(threadIdx.x%4); - const int n = tid / 4; - const int j = tid - 4*n; + const int64_t r = threadIdx.x/4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16*is0 + 4*(threadIdx.x%4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; + int64_t is = 8*n + 2*j + is0; int shift = 2*j; int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : @@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); #else - const int tid = threadIdx.x; - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t is = tid/16; // 0 or 1 + const int64_t il = tid%16; // 0...15 + const int64_t im = il/8; // 0...1 + const int64_t in = il%8; // 0...7 dst_t * y = yy + i*QK_K + 16*is + il; @@ -227,15 +227,15 @@ template static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t is = 2*il; + const int64_t n = 4; dst_t * y = yy + i*QK_K + 64*il + n*ir; @@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t y[l +32] = d2 * (q[l] >> 4) - m2; } #else - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; const uint8_t * q = x[i].qs; dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; @@ -268,14 +268,14 @@ template static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 + const int64_t tid = threadIdx.x; + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 dst_t * y = yy + i*QK_K + 64*il + 2*ir; @@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; #else - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 + const int64_t im = tid/8; // 0...3 + const int64_t in = tid%8; // 0...7 + const int64_t is = tid/16; // 0 or 1 const uint8_t h = x[i].qh[in] >> im; const float d = x[i].d; dst_t * y = yy + i*QK_K + tid; @@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_xxs * x = (const block_iq2_xxs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * aux8 = (const uint8_t *)q2; @@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds template static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_xs * x = (const block_iq2_xs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * q2 = x[i].qs + 4*ib; const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); @@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst template static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq2_s * x = (const block_iq2_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; @@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq3_xxs * x = (const block_iq3_xxs *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * q3 = x[i].qs + 8*ib; const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; @@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq3_s * x = (const block_iq3_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); @@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); @@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_m * x = (const block_iq1_m *) vx; - const int tid = threadIdx.x; + const int64_t tid = threadIdx.x; #if QK_K == 256 - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint16_t * sc = (const uint16_t *)x[i].scales; iq1m_scale_t scale; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; @@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[ib].qs + 4*il; const float d = (float)x[ib].d; @@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst #if QK_K != 64 template static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq4_xs * x = (const block_iq4_xs *)vx; - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 4*il; const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu new file mode 100644 index 000000000..df1e80068 --- /dev/null +++ b/ggml-cuda/fattn.cu @@ -0,0 +1,944 @@ +#include "common.cuh" +#include "fattn.cuh" + +#include + +#if FP16_MMA_AVAILABLE +#include +#endif + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nwarps*WARP_SIZE); + + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -HALF_MAX_HALF; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[threadIdx.x] = 0.0f; + } + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; + } + } + + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; + } + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; + + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; + + VKQ *= __half2half2(KQ_max_scale); + + __syncthreads(); + + if (tid < D) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } + } + + __syncthreads(); + } + + if (tid >= D) { + kqsum = 0.0f; + } + + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } + __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); + + if (tid >= D) { + return; + } + + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); + if (parallel_blocks == 1) { + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; + + if (parallel_blocks == 1 || tid != 0) { + return; + } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // FP16_MMA_AVAILABLE +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + ggml_cuda_set_device(ctx.device); + + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; +} diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh new file mode 100644 index 000000000..ad3ca7a8d --- /dev/null +++ b/ggml-cuda/fattn.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 9bda18e58..6ed225999 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,17 @@ #include "softmax.cuh" -template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template <> +__device__ float __forceinline__ t2f32(half val) { + return __half2float(val); +} + +template +static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -28,7 +38,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols; + float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; float max_val = -INFINITY; @@ -40,10 +50,10 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; + const int64_t ix = (int64_t)rowx*ncols + col; + const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -109,12 +119,13 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f return; } - const int idst = rowx*ncols + col; + const int64_t idst = (int64_t)rowx*ncols + col; dst[idst] = vals[col] * inv_sum; } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float * void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const void * src1_d = src1 ? (const void *)src1->data : nullptr; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = nullptr; + void * src2_d = nullptr; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (float *)src2->data; + src2_d = (void *)src2->data; } - soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + const half * src2_dd = (const half *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } else { + const float * src1_dd = (const float *)src1_d; + const float * src2_dd = (const float *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } } diff --git a/ggml-impl.h b/ggml-impl.h index 2087f7ded..94a1cc668 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -313,7 +313,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // defined(__ARM_NEON) -#if defined(__ARM_NEON) && !defined(__MSC_VER) +#if defined(__ARM_NEON) && !defined(_MSC_VER) #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 407062e6f..9a469821d 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml for (int i = node_start; i < node_end; ++i) { struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2); struct ggml_tensor * dst = gf->nodes[i]; GGML_ASSERT(dst->data != nullptr); @@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml { float scale; memcpy(&scale, dst->op_params, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); + GGML_ASSERT(src2 == nullptr); + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: diff --git a/ggml-metal.m b/ggml-metal.m index baeab8e63..49e6b0709 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -46,8 +46,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -177,6 +179,14 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { } /* - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -459,172 +469,182 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { return NULL; \ } \ } else { \ - GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ + GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } [metal_library release]; @@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); + int nth = 32; // SIMD width id pipeline = nil; + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } } float scale; @@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -2590,6 +2779,45 @@ static enum ggml_status ggml_metal_graph_compute( MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + MTLCommandBufferError error_code = [command_buffer error].code; + switch (error_code) { + case MTLCommandBufferErrorNone: + GGML_METAL_LOG_INFO("no error code reported\n"); + break; + case MTLCommandBufferErrorTimeout: + GGML_METAL_LOG_INFO("timeout\n"); + break; + case MTLCommandBufferErrorPageFault: + GGML_METAL_LOG_INFO("unserviceable page fault\n"); + break; + case MTLCommandBufferErrorOutOfMemory: + GGML_METAL_LOG_INFO("out of memory\n"); + break; + case MTLCommandBufferErrorInvalidResource: + GGML_METAL_LOG_INFO("invalid reference to resource\n"); + break; + case MTLCommandBufferErrorMemoryless: + GGML_METAL_LOG_INFO("GPU ran out of one or more of its internal resources that support memoryless render pass attachments\n"); + break; + case MTLCommandBufferErrorDeviceRemoved: + GGML_METAL_LOG_INFO("device removed\n"); + break; + case MTLCommandBufferErrorStackOverflow: + GGML_METAL_LOG_INFO("kernel function of tile shader used too many stack frames\n"); + break; + case MTLCommandBufferErrorAccessRevoked: + GGML_METAL_LOG_INFO("access to device revoked by system\n"); + break; + case MTLCommandBufferErrorInternal: + GGML_METAL_LOG_INFO("internal error\n"); + break; + default: + GGML_METAL_LOG_INFO("unknown error %lu\n", error_code); + break; + } + } + return GGML_STATUS_FAILED; } } @@ -2706,10 +2934,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2719,10 +2950,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2756,8 +2992,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + //ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2844,7 +3079,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2867,7 +3102,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2876,8 +3112,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 191880af1..3d4276ae0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -352,11 +352,12 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -375,10 +376,10 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device const float * ppos = src2 != src0 ? src2 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); float slope = 0.0f; @@ -456,11 +457,12 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -479,10 +481,10 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; float slope = 0.0f; @@ -499,7 +501,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -525,7 +527,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -562,6 +564,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_float8x8 mscale(scale); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + // mqk = mqk*scale + mask + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + float4 mm = (float4) mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml-quants.c b/ggml-quants.c index c0c26b124..b8fe01681 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12384,3 +12384,287 @@ void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) block_iq2_s * restrict y = vy; quantize_row_iq2_s_reference(x, y, k); } + +static bool validate_float(float f, size_t i) { + if (isinf(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +static bool isinf_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0; +} + +static bool isnan_fp16(ggml_fp16_t f) { + return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; +} + +static bool validate_fp16(ggml_fp16_t f, size_t i) { + if (isinf_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + +#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i)) { \ + return false; \ + } \ + } + +#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \ + return false; \ + } \ + } + +bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) { + if (type < 0 || type >= GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + + if (nbytes % ggml_type_size(type) != 0) { + fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type); + return false; + } + + const size_t nb = nbytes/ggml_type_size(type); + + switch (type) { + case GGML_TYPE_F16: + { + const ggml_fp16_t * f = (const ggml_fp16_t *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 15 < nb; i += 16) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00)); + __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 16; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 7 < nb; i += 8) { + uint16x8_t v = vld1q_u16(f + i); + uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00)); + uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_fp16(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F32: + { + const float * f = (const float *) data; + size_t i = 0; +#if defined(__AVX2__) + for (; i + 7 < nb; i += 8) { + __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); + __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000)); + __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000)); + int mask = _mm256_movemask_epi8(cmp); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#elif defined(__ARM_NEON) + for (; i + 3 < nb; i += 4) { + uint32x4_t v = vld1q_u32((const uint32_t *)f + i); + uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000)); + uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0); + if (mask) { + for (size_t j = 0; j < 4; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } +#endif + for (; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_F64: + { + const double * f = (const double *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(f[i], i)) { + return false; + } + } + } break; + case GGML_TYPE_Q4_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); + } break; + case GGML_TYPE_Q4_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m); + } break; + case GGML_TYPE_Q5_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb); + } break; + case GGML_TYPE_Q5_1: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m); + } break; + case GGML_TYPE_Q8_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); + } break; + case GGML_TYPE_Q2_K: + { + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); + } break; + case GGML_TYPE_Q3_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb); + } break; + case GGML_TYPE_Q4_K: + { + #ifdef GGML_QKK_64 + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin); + #endif + } break; + case GGML_TYPE_Q5_K: + { + #ifdef GGML_QKK_64 + VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb); + #else + VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin); + #endif + } break; + case GGML_TYPE_Q6_K: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb); + } break; + case GGML_TYPE_Q8_K: + { + const block_q8_K * q = (const block_q8_K *) data; + for (size_t i = 0; i < nb; ++i) { + if (!validate_float(q[i].d, i)) { + return false; + } + } + } break; + case GGML_TYPE_IQ1_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); + } break; + case GGML_TYPE_IQ1_M: + { + const block_iq1_m * q = (const block_iq1_m *) data; + for (size_t i = 0; i < nb; ++i) { + #if QK_K == 64 + if (!validate_fp16(q[i].d, i)) { + return false; + } + #else + iq1m_scale_t scale; + const uint16_t * sc = (const uint16_t *)q[i].scales; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + if (!validate_fp16(scale.f16, i)) { + return false; + } + #endif + } + } break; + case GGML_TYPE_IQ2_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb); + } break; + case GGML_TYPE_IQ2_XS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb); + } break; + case GGML_TYPE_IQ2_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb); + } break; + case GGML_TYPE_IQ3_XXS: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb); + } break; + + case GGML_TYPE_IQ3_S: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb); + } break; + case GGML_TYPE_IQ4_XS: + #if QK_K != 64 + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb); + } break; + #endif + // with QK_K == 64, iq4_xs is iq4_nl + case GGML_TYPE_IQ4_NL: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + // nothing to validate + break; + default: + { + fprintf(stderr, "%s: invalid type %d\n", __func__, type); + return false; + } + } + + return true; +} diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a9b310243..57fe4ea3d 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -13416,11 +13416,16 @@ void print_device_detail(int id, sycl::device &device, std::string device_type) version += std::to_string(prop.get_minor_version()); device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); + std::string name = std::string(prop.get_name()); + name = std::regex_replace(name, std::regex("\\(R\\)"), ""); + name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); - fprintf(stderr, "|%2d|%18s|%45s|%10s|%11d|%8d|%7d|%15lu|\n", id, device_type.c_str(), - prop.get_name(), version.c_str(), prop.get_max_compute_units(), + auto global_mem_size = prop.get_global_mem_size()/1000000; + + fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), + name.c_str(), version.c_str(), prop.get_max_compute_units(), prop.get_max_work_group_size(), prop.get_max_sub_group_size(), - prop.get_global_mem_size()); + global_mem_size, device.get_info().c_str()); } void ggml_backend_sycl_print_sycl_devices() { @@ -13428,9 +13433,10 @@ void ggml_backend_sycl_print_sycl_devices() { int device_count = dpct::dev_mgr::instance().device_count(); std::map DeviceNums; fprintf(stderr, "found %d SYCL devices:\n", device_count); - fprintf(stderr, "| | | |Compute |Max compute|Max work|Max sub| |\n"); - fprintf(stderr, "|ID| Device Type| Name|capability|units |group |group |Global mem size|\n"); - fprintf(stderr, "|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|\n"); + fprintf(stderr, "| | | | |Max | |Max |Global | |\n"); + fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n"); + fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n"); + fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n"); for (int id = 0; id < device_count; ++id) { sycl::device device = dpct::dev_mgr::instance().get_device(id); sycl::backend backend = device.get_backend(); @@ -14738,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + const ggml_tensor * src2 = dst->src[2]; + +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14754,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, float * src2_dd = nullptr; sycl_pool_alloc src2_f; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 1736ab736..f712cdd5a 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } diff --git a/ggml.c b/ggml.c index 7906a82f5..093588b4b 100644 --- a/ggml.c +++ b/ggml.c @@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -1046,7 +1046,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) @@ -1144,7 +1144,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", @@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4560,6 +4622,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5398,17 +5462,23 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F32); + GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); GGML_ASSERT(pos->ne[0] == a->ne[0]); } + if (pos && mask) { + GGML_ASSERT(pos->type == mask->type); + } + if (max_bias > 0.0f) { GGML_ASSERT(pos); } @@ -6217,6 +6287,59 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -12256,7 +12379,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -12279,19 +12402,31 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; + float * pos_f32 = src2 ? (float *) src2->data : src0->data; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += mp_f32[i]; + } + } } // ALiBi bias @@ -12299,8 +12434,14 @@ static void ggml_compute_forward_soft_max_f32( const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } @@ -14570,6 +14711,198 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(ggml_fp16_t)); + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + + ggml_vec_dot_f16(D, + &s, 0, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: + case GGML_PREC_F32: + { + // uses F32 accumulators + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -16377,6 +16710,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor); @@ -17389,6 +17726,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -18161,6 +18499,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -18564,6 +18903,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { @@ -20629,7 +20974,7 @@ static void gguf_free_kv(struct gguf_kv * kv) { } struct gguf_context * gguf_init_empty(void) { - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = GGUF_VERSION; @@ -20674,7 +21019,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p bool ok = true; - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); // read the header { @@ -20727,9 +21072,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the kv pairs { - ctx->kv = GGML_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv)); + const uint64_t n_kv = ctx->header.n_kv; - for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { + // header.n_kv will hold the actual value of pairs that were successfully read in the loop below + ctx->header.n_kv = 0; + ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv)); + + for (uint64_t i = 0; i < n_kv; ++i) { struct gguf_kv * kv = &ctx->kv[i]; //fprintf(stderr, "%s: reading kv %d\n", __func__, i); @@ -20786,7 +21135,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p return NULL; } - kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * gguf_type_size(kv->value.arr.type)); + kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type)); ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset); } break; @@ -20800,7 +21149,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p return NULL; } - kv->value.arr.data = GGML_MALLOC(kv->value.arr.n * sizeof(struct gguf_str)); + kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str)); for (uint64_t j = 0; j < kv->value.arr.n; ++j) { ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); @@ -20816,6 +21165,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p if (!ok) { break; } + + ctx->header.n_kv++; } if (!ok) { @@ -20828,7 +21179,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the tensor infos { - ctx->infos = GGML_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info)); for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { struct gguf_tensor_info * info = &ctx->infos[i]; @@ -20855,8 +21206,17 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + // TODO: return an error instead of crashing with GGML_ASSERT gguf_tensor_info_sanitize(info); + // make sure there is no duplicated tensor names + for (uint64_t j = 0; j < i; ++j) { + if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { + fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); + ok = false; + } + } + if (!ok) { fprintf(stderr, "%s: failed to read tensor info\n", __func__); fclose(file); @@ -21025,7 +21385,7 @@ void gguf_free(struct gguf_context * ctx) { GGML_FREE(ctx->infos); } - GGML_ALIGNED_FREE(ctx); + GGML_FREE(ctx); } const char * gguf_type_name(enum gguf_type type) { @@ -21336,7 +21696,7 @@ void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_ty ctx->kv[idx].type = GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = type; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_MALLOC(n*gguf_type_size(type)); + ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type)); memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type)); } @@ -21346,7 +21706,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** ctx->kv[idx].type = GGUF_TYPE_ARRAY; ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_MALLOC(n*sizeof(struct gguf_str)); + ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str)); for (int i = 0; i < n; i++) { struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; str->n = strlen(data[i]); @@ -21373,7 +21733,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { case GGUF_TYPE_ARRAY: { if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { - const char ** data = GGML_MALLOC(src->kv[i].value.arr.n*sizeof(char *)); + const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *)); for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; } @@ -21393,6 +21753,10 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { void gguf_add_tensor( struct gguf_context * ctx, const struct ggml_tensor * tensor) { + if (gguf_find_tensor(ctx, tensor->name) != -1) { + GGML_ASSERT(false && "duplicated tensor name"); + } + const int idx = ctx->header.n_tensors; ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); @@ -21461,7 +21825,7 @@ struct gguf_buf { static struct gguf_buf gguf_buf_init(size_t size) { struct gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : GGML_MALLOC(size), + /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size), /*buf.size =*/ size, /*buf.offset =*/ 0, }; diff --git a/ggml.h b/ggml.h index 2fe24d91a..1e4d1644f 100644 --- a/ggml.h +++ b/ggml.h @@ -482,6 +482,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, @@ -769,6 +770,8 @@ extern "C" { // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); + GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes); + // main GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); @@ -1727,6 +1730,25 @@ extern "C" { struct ggml_tensor * v, bool masked); +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d2f1de198..6d597bfd9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -72,6 +72,7 @@ class Keys: class Tokenizer: MODEL = "tokenizer.ggml.model" + PRE = "tokenizer.ggml.pre" LIST = "tokenizer.ggml.tokens" TOKEN_TYPE = "tokenizer.ggml.token_type" TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types @@ -940,6 +941,7 @@ KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK # tokenization KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL +KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index 33afac552..2bdb15525 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -139,8 +139,13 @@ class GGUFReader: def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: - raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') - self.fields[field.name] = field + # TODO: add option to generate error on duplicate keys + # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') + + print(f'Warning: Duplicate key {field.name} at offset {field.offset}') + self.fields[field.name + '_{}'.format(field.offset)] = field + else: + self.fields[field.name] = field return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: @@ -234,8 +239,14 @@ class GGUFReader: def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: tensors = [] + tensor_names = set() # keep track of name to prevent duplicated tensors for field in fields: _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts + # check if there's any tensor having same name already in the list + tensor_name = str(bytes(name_data), encoding = 'utf-8') + if tensor_name in tensor_names: + raise ValueError(f'Found duplicated tensor with name {tensor_name}') + tensor_names.add(tensor_name) ggml_type = GGMLQuantizationType(raw_dtype[0]) n_elems = np.prod(dims) block_size, type_size = GGML_QUANT_SIZES[ggml_type] @@ -267,7 +278,7 @@ class GGUFReader: item_count = n_bytes item_type = np.uint8 tensors.append(ReaderTensor( - name = str(bytes(name_data), encoding = 'utf-8'), + name = tensor_name, tensor_type = ggml_type, shape = dims, n_elements = n_elems, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e3dbca454..089aece87 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -63,6 +63,7 @@ class GGUFWriter: self.kv_data_count = 0 self.ti_data = bytearray() self.ti_data_count = 0 + self.ti_names = set() self.use_temp_file = use_temp_file self.temp_file = None self.tensors = [] @@ -197,6 +198,10 @@ class GGUFWriter: if self.state is not WriterState.EMPTY: raise ValueError(f'Expected output file to be empty, got {self.state}') + if name in self.ti_names: + raise ValueError(f'Duplicated tensor name {name}') + self.ti_names.add(name) + encoded_name = name.encode("utf8") self.ti_data += self._pack("Q", len(encoded_name)) self.ti_data += encoded_name @@ -422,6 +427,9 @@ class GGUFWriter: def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) + def add_tokenizer_pre(self, pre: str) -> None: + self.add_string(Keys.Tokenizer.PRE, pre) + def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: self.add_array(Keys.Tokenizer.LIST, tokens) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 3eeeec5dc..422abcb42 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -177,7 +177,7 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector } else { - output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, true, true); + output_tokens = ::llama_tokenize(llama_ctx_v4, str_to_tokenize, add_bos, true); if(add_bos) { llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model)); @@ -256,6 +256,15 @@ static int GetEosID(FileFormat file_format, int32_t n_vocab) } return eosID; } +static int GetEotID(FileFormat file_format) +{ + if(file_format == FileFormat::GGUF_GENERIC) + { + return llama_token_eot(&(llama_ctx_v4->model)); + } + return -1; +} + static float LowestLogit(const std::vector & logits) { int topid = std::min_element(logits.begin(), logits.end()) - logits.begin(); @@ -484,6 +493,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar } const llama_token eos = GetEosID(file_format,n_vocab); + const llama_token eot = GetEotID(file_format); std::vector, llama_partial_utf8>> candidates_decoded; std::vector candidates_grammar; @@ -491,7 +501,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; const std::string piece = FileFormatTokenizeID(id,file_format); - if (id == eos) { + if (id == eos || (id==eot && id!=-1)) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; } @@ -602,7 +612,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vectorstacks) { if (stack.empty()) { return; @@ -1601,12 +1611,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs) //if it tokenizes to a single token, AND it's a single non-printable special token, use that std::vector tmp; TokenizeString(stopper, tmp, file_format, false); + printf("\nPRINT TOK VEC:"); + print_tok_vec_str(tmp); if(tmp.size()==1) //tokenizes to exactly 1 special token { int specialid = tmp[0]; std::string tokenizedstr = FileFormatTokenizeID(specialid, file_format); + printf("\nTest %s",tokenizedstr.c_str()); if(tokenizedstr=="") //must NOT have a text representation { + printf("\nAdded %d",specialid); special_stop_sequence.push_back(specialid); } } @@ -2167,6 +2181,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) } unsigned int eosID = GetEosID(file_format, n_vocab); + unsigned int eotID = GetEotID(file_format); float * logitsPtr; float lowestLogit = 0; int btsize = banned_token_ids.size(); @@ -2196,6 +2211,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs) { // set the logit of the eos token to very low to avoid sampling it logitsPtr[eosID] = lowestLogit; + if(eotID!=-1) + { + logitsPtr[eotID] = lowestLogit; + } } if(btsize>0) { @@ -2257,7 +2276,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) printf("]\n"); } - if(inputs.allow_eos_token && id==eosID) + if(inputs.allow_eos_token && (id==eosID || (id==eotID && id!=-1))) { stopper_unused_tokens = remaining_tokens; if(allow_regular_prints) diff --git a/klite.embd b/klite.embd index e146ec7ba..573374ab0 100644 --- a/klite.embd +++ b/klite.embd @@ -8791,8 +8791,8 @@ Current version: 136 et = "<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; break; case "9": //llama 3 chat - st = "<|eot_id|><|start_header_id|>user<|end_header_id|>"; - et = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"; + st = "<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n"; + et = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n"; break; default: break; diff --git a/llama.cpp b/llama.cpp index ce5d65072..861d969ae 100644 --- a/llama.cpp +++ b/llama.cpp @@ -78,6 +78,7 @@ #include #include #include +#include #include #include #include @@ -110,7 +111,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 - // // logging // @@ -338,6 +338,7 @@ enum llm_kv { LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, @@ -414,6 +415,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, @@ -1869,7 +1871,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1959,6 +1961,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -2062,8 +2065,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -2140,7 +2143,8 @@ struct llama_vocab { ttype type; }; - enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; std::unordered_map token_to_id; std::vector id_to_token; @@ -2365,11 +2369,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, ggml_type type_k, ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2380,6 +2387,7 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; // TODO: support mixed reccurent Transformer architectues // NOTE: (!a || b) is a logical implication (a -> b) @@ -2592,6 +2600,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -2912,6 +2924,7 @@ namespace GGUFMeta { case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; } return "unknown"; } @@ -2923,13 +2936,16 @@ namespace GGUFMeta { __func__, override_type_to_str(ovrd->tag), ovrd->key); switch (ovrd->tag) { case LLAMA_KV_OVERRIDE_TYPE_BOOL: { - LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false"); + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); } break; case LLAMA_KV_OVERRIDE_TYPE_INT: { - LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value); + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); } break; case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { - LLAMA_LOG_INFO("%.6f\n", ovrd->float_value); + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); } break; default: // Shouldn't be possible to end up here, but just in case... @@ -2948,7 +2964,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { - target = ovrd->bool_value; + target = ovrd->val_bool; return true; } return false; @@ -2958,7 +2974,7 @@ namespace GGUFMeta { static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { - target = ovrd->int_value; + target = ovrd->val_i64; return true; } return false; @@ -2968,7 +2984,7 @@ namespace GGUFMeta { static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { - target = ovrd->float_value; + target = ovrd->val_f64; return true; } return false; @@ -2977,12 +2993,11 @@ namespace GGUFMeta { template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override * ovrd) { - (void)target; - (void)ovrd; - if (!ovrd) { return false; } - // Currently, we should never end up here so it would be a bug if we do. - throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", - ovrd ? ovrd->key : "NULL")); + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; } static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { @@ -3015,6 +3030,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool check_tensors; llama_files files; llama_ftype ftype; @@ -3048,7 +3064,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3145,9 +3161,17 @@ struct llama_model_loader { fver = (enum llama_fver) gguf_get_version(meta); + std::set tensor_names; for (auto & w : weights) { n_elements += ggml_nelements(w.tensor); n_bytes += ggml_nbytes(w.tensor); + // make sure there is no duplicated tensor names + const std::string name(w.tensor->name); + auto found = tensor_names.find(name); + if (found != tensor_names.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name)); + } + tensor_names.insert(name); } LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", @@ -3254,6 +3278,7 @@ struct llama_model_loader { } this->use_mmap = use_mmap; + this->check_tensors = check_tensors; } ~llama_model_loader() { @@ -3512,6 +3537,10 @@ struct llama_model_loader { file->seek(w.offs, SEEK_SET); file->read_raw(cur->data, ggml_nbytes(cur)); } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } size_t size_done = 0; @@ -3528,6 +3557,8 @@ struct llama_model_loader { GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector> read_buf; + std::vector>> validation_result; + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(ggml_get_name(cur)); if (weight == nullptr) { @@ -3549,31 +3580,47 @@ struct llama_model_loader { if (bufs_mmap.count(weight->idx)) { buf_mmap = bufs_mmap.at(weight->idx); } + uint8_t * data = (uint8_t *) mapping->addr + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated if (buf_mmap && cur->data == nullptr) { - ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs); + ggml_backend_tensor_alloc(buf_mmap, cur, data); if (lmlocks) { const auto & lmlock = lmlocks->at(weight->idx); - lmlock->grow_to(weight->offs + ggml_nbytes(cur)); + lmlock->grow_to(weight->offs + n_size); } auto & mmap_used = mmaps_used[weight->idx]; mmap_used.first = std::min(mmap_used.first, weight->offs); mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); } else { - ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size); + ggml_backend_tensor_set(cur, data, 0, n_size); } } else { GGML_ASSERT(weight->idx < files.size()); const auto & file = files.at(weight->idx); if (ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); - file->read_raw(cur->data, ggml_nbytes(cur)); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } } else { - read_buf.resize(ggml_nbytes(cur)); + read_buf.resize(n_size); file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), ggml_nbytes(cur)); + file->read_raw(read_buf.data(), n_size); ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } } @@ -3593,6 +3640,19 @@ struct llama_model_loader { size_done += n_size; } + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + // check if this is the last call and do final cleanup if (size_done >= size_data) { // unmap offloaded tensors and metadata @@ -4186,7 +4246,7 @@ static void llm_load_hparams( model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); @@ -4210,11 +4270,13 @@ static void llm_load_vocab( // determine vocab type { - std::string tokenizer_name; + std::string tokenizer_model; + std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - if (tokenizer_name == "no_vocab") { + if (tokenizer_model == "no_vocab") { vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens @@ -4228,7 +4290,7 @@ static void llm_load_vocab( vocab.linefeed_id = -1; return; - } else if (tokenizer_name == "llama") { + } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens @@ -4273,9 +4335,27 @@ static void llm_load_vocab( if (add_space_prefix_keyidx != -1) { vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); } // The default value of add_space_prefix is true. - } else if (tokenizer_name == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else if (tokenizer_model == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = 100; + vocab.special_sep_id = 102; + vocab.special_pad_id = 0; + vocab.special_cls_id = 101; + vocab.special_mask_id = 103; + vocab.add_space_prefix = false; + } else { + if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + return; + } // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -4318,23 +4398,50 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; - } else if (tokenizer_name == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; + } - // default special tokens - vocab.special_bos_id = -1; - vocab.special_eos_id = -1; - vocab.special_unk_id = 100; - vocab.special_sep_id = 102; - vocab.special_pad_id = 0; - vocab.special_cls_id = 101; - vocab.special_mask_id = 103; - vocab.add_space_prefix = false; + // for now, only BPE models have pre-tokenizers + if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "default") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + } else if ( + tokenizer_pre == "deepseek-llm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + } else if ( + tokenizer_pre == "deepseek-coder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + } else if ( + tokenizer_pre == "falcon") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - - vocab.type = LLAMA_VOCAB_TYPE_SPM; + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } } @@ -6045,7 +6152,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); model.hparams.vocab_only = params.vocab_only; @@ -6174,37 +6281,47 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(kv.size == n_ctx); - // compute the transposed [n_tokens, n_embd] V matrix - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + + struct ggml_tensor * v_cache_view = nullptr; + + if (cparams.flash_attn) { + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); + + v_cur = ggml_transpose(ctx, v_cur); + } cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } static struct ggml_tensor * llm_build_norm( @@ -6424,11 +6541,11 @@ static struct ggml_tensor * llm_build_moe_ffn( return moe_out; } -// if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6436,12 +6553,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6459,72 +6576,100 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + struct ggml_tensor * cur; - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + // note: if this assert triggers, then some check has failed earlier + // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024. But koboldcpp will deal with it.") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.use_alibi) { + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } + + GGML_ASSERT(kv.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); } - GGML_ASSERT(kv.size == n_ctx); - - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); - ggml_build_forward_expand(graph, cur); cur = ggml_mul_mat(ctx, wo, cur); @@ -6543,6 +6688,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6552,7 +6698,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6566,12 +6711,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6613,6 +6758,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6659,6 +6806,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6773,15 +6921,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6811,20 +6975,26 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { + if (causal) { + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + } else { + // TODO: this will be needed for ALiBi-based BERT models + // https://github.com/ggerganov/llama.cpp/pull/6826 + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); + } cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; } struct ggml_tensor * build_inp_mean() { @@ -6930,9 +7100,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7070,9 +7240,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7177,9 +7347,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7297,9 +7467,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7422,9 +7592,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7574,9 +7744,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7686,9 +7856,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7890,9 +8060,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7986,9 +8156,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8279,9 +8449,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8410,14 +8580,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8559,9 +8730,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8677,9 +8848,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8790,9 +8961,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8904,9 +9075,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9059,9 +9230,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9176,9 +9347,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9289,9 +9460,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9392,9 +9563,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9499,9 +9670,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9615,9 +9786,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9732,9 +9903,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9862,9 +10033,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9983,9 +10154,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -10102,9 +10273,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10392,9 +10563,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10523,9 +10694,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10952,7 +11123,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { + // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch + // this allows to process multiple sequences in parallel with ALiBi-based models + if (hparams.use_alibi) { const int64_t n_kv = kv_self.n; GGML_ASSERT(lctx.inp_KQ_pos); @@ -11334,7 +11507,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -11502,6 +11675,10 @@ static int llama_decode_internal( } } + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(lctx.sched); + return 0; } @@ -11527,7 +11704,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -12294,7 +12473,79 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - auto word_collection = bpe_gpt2_preprocess(text); + + std::vector word_collection; + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_BPE: + switch (vocab.type_pre) { + case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + word_collection = unicode_regex_split(text, { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + + // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + "\\s+$", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: + word_collection = unicode_regex_split(text, { + "[\r\n]", + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_FALCON: + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_MPT: + // TODO: MPT pre-tokenization regexes are unknown + // the following are close, but not exact. run the following: + // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf + GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); + word_collection = unicode_regex_split(text, { + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + case LLAMA_VOCAB_PRE_TYPE_STARCODER: + case LLAMA_VOCAB_PRE_TYPE_GPT2: + word_collection = unicode_regex_split(text, { + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }); + break; + default: + // default regex for BPE tokenization pre-processing + word_collection = unicode_regex_split(text, { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }); + break; + } + break; + default: + GGML_ASSERT(false); + break; + } symbols_final.clear(); @@ -12421,145 +12672,6 @@ private: work_queue.push(bigram); } - std::vector bpe_gpt2_preprocess(const std::string & text) { - std::vector bpe_words; - std::vector bpe_encoded_words; - - std::string token = ""; - // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ - bool collecting_numeric = false; - bool collecting_letter = false; - bool collecting_special = false; - bool collecting_whitespace_lookahead = false; - bool collecting = false; - - std::vector text_utf; - text_utf.reserve(text.size()); - bpe_words.reserve(text.size()); - bpe_encoded_words.reserve(text.size()); - - const auto cpts = unicode_cpts_from_utf8(text); - for (size_t i = 0; i < cpts.size(); ++i) - text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); - - for (int i = 0; i < (int)text_utf.size(); i++) { - const std::string & utf_char = text_utf[i]; - bool split_condition = false; - int bytes_remain = text_utf.size() - i; - // forward backward lookups - const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; - const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; - - // handling contractions - if (!split_condition && bytes_remain >= 2) { - // 's|'t|'m|'d - if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { - split_condition = true; - } - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next; - bpe_words.emplace_back(token); - token = ""; - i++; - continue; - } - } - if (!split_condition && bytes_remain >= 3) { - // 're|'ve|'ll - if (utf_char == "\'" && ( - (utf_char_next == "r" && utf_char_next_next == "e") || - (utf_char_next == "v" && utf_char_next_next == "e") || - (utf_char_next == "l" && utf_char_next_next == "l")) - ) { - split_condition = true; - } - if (split_condition) { - // current token + next token can be defined - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next + utf_char_next_next; - bpe_words.emplace_back(token); // the contraction - token = ""; - i += 2; - continue; - } - } - - if (!split_condition && !collecting) { - if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { - collecting_letter = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - collecting_numeric = true; - collecting = true; - } - else if ( - ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || - (!token.size() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) - ) { - collecting_special = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { - collecting_whitespace_lookahead = true; - collecting = true; - } - else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { - split_condition = true; - } - } - else if (!split_condition && collecting) { - if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { - split_condition = true; - } - else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { - split_condition = true; - } - else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { - split_condition = true; - } - else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - split_condition = true; - } - } - - if (utf_char_next == "") { - split_condition = true; // final - token += utf_char; - } - - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); - } - token = utf_char; - collecting = false; - collecting_letter = false; - collecting_numeric = false; - collecting_special = false; - collecting_whitespace_lookahead = false; - } - else { - token += utf_char; - } - } - - for (std::string & word : bpe_words) { - std::string encoded_token = ""; - for (char & c : word) { - encoded_token += unicode_byte_to_utf8(c); - } - bpe_encoded_words.emplace_back(encoded_token); - } - - return bpe_encoded_words; - } - const llama_vocab & vocab; std::vector symbols; @@ -12879,7 +12991,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } break; case LLAMA_VOCAB_TYPE_BPE: { - if (add_special && vocab.special_add_bos == 1) { + if (add_special && vocab.special_add_bos != 0) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); } @@ -14675,14 +14787,20 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { + if (nthread < 2) { + // single-thread + size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + if (!ggml_validate_row_data(new_type, new_data, new_size)) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; + } + std::mutex mutex; int64_t counter = 0; size_t new_size = 0; - if (nthread < 2) { - // single-thread - return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); - } - auto compute = [&mutex, &counter, &new_size, new_type, f32_data, new_data, chunk_size, + bool valid = true; + auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix]() { const int64_t nrows_per_chunk = chunk_size / n_per_row; size_t local_size = 0; @@ -14697,7 +14815,17 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa } lock.unlock(); const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); - local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data + const size_t row_size = ggml_row_size(new_type, n_per_row); + void * this_data = (char *) new_data + first_row * row_size; + if (!ggml_validate_row_data(new_type, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } } }; for (int it = 0; it < nthread - 1; ++it) { @@ -14706,6 +14834,9 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa compute(); for (auto & w : workers) { w.join(); } workers.clear(); + if (!valid) { + throw std::runtime_error("quantized data validation failed"); + } return new_size; } @@ -14768,7 +14899,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model; @@ -14806,11 +14937,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s for (auto & o : overrides) { if (o.key[0] == 0) break; if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out, o.key, o.float_value); + gguf_set_val_f32(ctx_out, o.key, o.val_f64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { - gguf_set_val_i32(ctx_out, o.key, o.int_value); + gguf_set_val_i32(ctx_out, o.key, o.val_i64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out, o.key, o.bool_value); + gguf_set_val_bool(ctx_out, o.key, o.val_bool); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out, o.key, o.val_str); } else { LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); } @@ -15129,7 +15262,7 @@ static int llama_apply_lora_from_file_internal( std::unique_ptr ml; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr)); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr)); ml->init_mappings(/*prefetch*/ false); // no prefetching } @@ -15388,6 +15521,7 @@ struct llama_model_params llama_model_default_params() { /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, + /*.check_tensors =*/ false, }; #ifdef GGML_USE_METAL @@ -15424,6 +15558,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -15580,6 +15715,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -15587,12 +15723,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : @@ -15624,6 +15768,23 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && hparams.use_alibi) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); + cparams.flash_attn = false; + } + + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + +#ifdef GGML_USE_HIPBLAS + if (cparams.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__); + cparams.flash_attn = false; + } +#endif + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } @@ -15631,6 +15792,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -15759,7 +15921,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -16358,6 +16520,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -16375,10 +16538,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) { + s_kv_head + s_kv_size + s_kv_used + + s_v_trans + s_kv + s_kv_cells ); + // on session change it is very likely that the state size has changed - so we need to update this function + static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + return s_total; } @@ -16524,11 +16691,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -16541,7 +16710,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16674,11 +16843,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state @@ -16688,6 +16861,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -16699,7 +16874,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16721,8 +16896,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -16982,28 +17155,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } - // For the values, they are transposed, so we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (int il = 0; il < (int)n_layer; ++il) { - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - data_ctx.write(&v_type_i, sizeof(v_type_i)); + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write element size - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - data_ctx.write(&v_size_el, sizeof(v_size_el)); + // Write row size of value + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - tmp_buf.resize(range_size * v_size_el); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + tmp_buf.resize(range_size * v_size_row); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } } return data_ctx.get_size_written(); @@ -17128,41 +17322,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } - // For each layer, read the values for each cell (transposed) - for (int il = 0; il < (int)n_layer; ++il) { - // Read type of value - int32_t v_type_i_ref; - memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); - inp += sizeof(v_type_i_ref); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return 0; - } + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } - // Read element size of value - size_t v_size_el_ref; - memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); - inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); - return 0; - } + // Read row size of value + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } } const size_t nread = inp - src; + return nread; } @@ -17974,9 +18202,9 @@ const char * llama_print_system_info(void) { s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | "; #ifdef GGML_USE_LLAMAFILE - s += "LAMMAFILE = 1 | "; + s += "LLAMAFILE = 1 | "; #else - s += "LAMMAFILE = 0 | "; + s += "LLAMAFILE = 0 | "; #endif return s.c_str(); diff --git a/llama.h b/llama.h index 678e5386d..5387c99c7 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 @@ -69,6 +69,18 @@ extern "C" { LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece }; + // pre-tokenization types + enum llama_vocab_pre_type { + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + }; + // note: these values should be synchronized with ggml_rope // TODO: maybe move this enum to ggml.h (ggml_rope_type) enum llama_rope_type { @@ -195,15 +207,19 @@ extern "C" { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, LLAMA_KV_OVERRIDE_TYPE_BOOL, + LLAMA_KV_OVERRIDE_TYPE_STR, }; struct llama_model_kv_override { - char key[128]; enum llama_model_kv_override_type tag; + + char key[128]; + union { - int64_t int_value; - double float_value; - bool bool_value; + int64_t val_i64; + double val_f64; + bool val_bool; + char val_str[128]; }; }; @@ -232,9 +248,10 @@ extern "C" { const struct llama_model_kv_override * kv_overrides; // Keep the booleans together to avoid misalignment during copy-by-value. - bool vocab_only; // only load the vocabulary, no weights - bool use_mmap; // use mmap if possible - bool use_mlock; // force system to keep model in RAM + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool check_tensors; // validate model tensor data }; struct llama_context_params { @@ -270,6 +287,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -525,7 +543,7 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache + // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); diff --git a/sgemm.cpp b/sgemm.cpp index 531e12af3..4e0159804 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -50,7 +50,6 @@ #pragma GCC diagnostic ignored "-Wignored-attributes" #include "sgemm.h" -#include #include "ggml-impl.h" #include "ggml-quants.h" @@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) { template class tinyBLAS { public: - tinyBLAS(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: mc = 5; @@ -409,27 +408,27 @@ class tinyBLAS { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; - for (int l = 0; l < k; l += KN) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; l += KN) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(load(A + lda * (ii + i) + l), load(B + ldb * (jj + j) + l), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -437,10 +436,10 @@ class tinyBLAS { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -452,23 +451,23 @@ class tinyBLAS { template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int k, - const TA *A, int lda, - const block_q8_0 *B, int ldb, - float *C, int ldc, + tinyBLAS_Q0_ARM(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; nc = 3; @@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = vmlaq_n_f32(Cv[j][i], vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), @@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM { load_hi(B + ldb * (jj + j) + l))), unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM { const TA *const A; const block_q8_0 *const B; float *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM { template class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX2(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS_Q0_AVX2(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: mc = 4; @@ -714,22 +713,22 @@ class tinyBLAS_Q0_AVX2 { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), @@ -737,8 +736,8 @@ class tinyBLAS_Q0_AVX2 { _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ -bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C, - int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { +bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, + int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); @@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, assert(ldc >= m); assert(nth > 0); assert(ith < nth); - assert(1ll * lda * m <= 0x7fffffff); - assert(1ll * ldb * n <= 0x7fffffff); - assert(1ll * ldc * n <= 0x7fffffff); if (Ctype != GGML_TYPE_F32) return false; diff --git a/sgemm.h b/sgemm.h index da23b209c..f29747d0a 100644 --- a/sgemm.h +++ b/sgemm.h @@ -1,11 +1,13 @@ #pragma once +#include #include #ifdef __cplusplus extern "C" { #endif -bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, - void *, int, int, int, int, int, int, int); +bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, + const void *, int64_t, void *, int64_t, int, int, + int, int, int, int); #ifdef __cplusplus } diff --git a/unicode-data.cpp b/unicode-data.cpp index 22f8b0f0b..e6bafb3a9 100644 --- a/unicode-data.cpp +++ b/unicode-data.cpp @@ -1,4 +1,4 @@ -#include "unicode-data.h" +#include "unicode-data.h" #include #include diff --git a/unicode-data.h b/unicode-data.h index b99500b8f..cb9dd8aa5 100644 --- a/unicode-data.h +++ b/unicode-data.h @@ -12,5 +12,5 @@ extern const std::vector> unicode_ranges_accent_ma extern const std::vector> unicode_ranges_punctuation; extern const std::vector> unicode_ranges_symbol; extern const std::vector> unicode_ranges_control; -extern const std::multimap unicode_map_nfd; -extern const std::map unicode_map_lowercase; +extern const std::multimap unicode_map_nfd; +extern const std::map unicode_map_lowercase; diff --git a/unicode.cpp b/unicode.cpp index df8c5f581..f2ccda05f 100644 --- a/unicode.cpp +++ b/unicode.cpp @@ -5,11 +5,14 @@ #include #include #include +#include #include #include #include #include #include +#include +#include static std::string unicode_cpts_to_utf8(const std::vector & cps) { std::string result; @@ -53,23 +56,22 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) offset += 4; return result; } - throw std::invalid_argument("invalid string"); + throw std::invalid_argument("failed to convert utf8 to codepoint"); } -static std::vector unicode_cpt_to_utf16(uint32_t cp) { - std::vector result; - if (/* 0x0000 <= cp && */ cp <= 0xffff) { - result.emplace_back(cp); - } - else if (0x10000 <= cp && cp <= 0x10ffff) { - result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); - result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); - } - else { - throw std::invalid_argument("invalid cpt"); - } - return result; -} +//static std::vector unicode_cpt_to_utf16(uint32_t cp) { +// std::vector result; +// if (/* 0x0000 <= cp && */ cp <= 0xffff) { +// result.emplace_back(cp); +// return result; +// } +// if (0x10000 <= cp && cp <= 0x10ffff) { +// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); +// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); +// return result; +// } +// throw std::invalid_argument("failed to convert codepoint to utf16"); +//} //static std::vector unicode_cpts_to_utf16(const std::vector & cps) { // std::vector result; @@ -80,28 +82,28 @@ static std::vector unicode_cpt_to_utf16(uint32_t cp) { // return result; //} -static uint32_t cpt_from_utf16(const std::vector & utf16, size_t & offset) { - assert(offset < utf16.size()); - if (((utf16[0] >> 10) << 10) != 0xd800) { - auto result = utf16[offset + 0]; - offset += 1; - return result; - } - - if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { - throw std::invalid_argument("invalid character"); - } - - auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); - offset += 2; - return result; -} +//static uint32_t unicode_cpt_from_utf16(const std::vector & utf16, size_t & offset) { +// assert(offset < utf16.size()); +// if (((utf16[0] >> 10) << 10) != 0xd800) { +// auto result = utf16[offset + 0]; +// offset += 1; +// return result; +// } +// +// if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) { +// throw std::invalid_argument("invalid character"); +// } +// +// auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); +// offset += 2; +// return result; +//} //static std::vector unicode_cpts_from_utf16(const std::vector & utf16) { // std::vector result; // size_t offset = 0; // while (offset < utf16.size()) { -// result.push_back(cpt_from_utf16(utf16, offset)); +// result.push_back(unicode_cpt_from_utf16(utf16, offset)); // } // return result; //} @@ -194,34 +196,277 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } +static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { + std::wstring_convert> conv; + return conv.from_bytes(s); +} + +static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { + std::vector bpe_encoded_words; + for (const auto & word : bpe_words) { + std::string text_utf; + auto utf_word = unicode_cpts_from_utf8(word); + for (size_t i = 0; i < utf_word.size(); ++i) { + text_utf += unicode_cpt_to_utf8(utf_word[i]); + } + + std::string encoded_token; + for (char & c : text_utf) { + encoded_token += unicode_byte_to_utf8(c); + } + bpe_encoded_words.emplace_back(encoded_token); + } + return bpe_encoded_words; +} + +// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_gpt2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + size_t start = 0; + + const auto cpts = unicode_cpts_from_utf8(text); + + for (auto offset : offsets) { + std::string token; + + bool collecting_numeric = false; + bool collecting_letter = false; + bool collecting_special = false; + bool collecting_whitespace_lookahead = false; + bool collecting = false; + + std::vector text_utf; + text_utf.reserve(offset); + + for (size_t i = start; i < start + offset; ++i) { + text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i])); + } + + for (int i = 0; i < (int)text_utf.size(); i++) { + const std::string & utf_char = text_utf[i]; + bool split_condition = false; + int bytes_remain = text_utf.size() - i; + + // forward backward lookups + const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; + const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; + + // handling contractions + if (!split_condition && bytes_remain >= 2) { + // 's|'t|'m|'d + if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { + split_condition = true; + } + if (split_condition) { + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char + utf_char_next; + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + token = ""; + i++; + continue; + } + } + if (!split_condition && bytes_remain >= 3) { + // 're|'ve|'ll + if (utf_char == "\'" && ( + (utf_char_next == "r" && utf_char_next_next == "e") || + (utf_char_next == "v" && utf_char_next_next == "e") || + (utf_char_next == "l" && utf_char_next_next == "l")) + ) { + split_condition = true; + } + if (split_condition) { + // current token + next token can be defined + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char; + token += utf_char_next; + token += utf_char_next_next; + + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + token = ""; + i += 2; + continue; + } + } + + if (!split_condition && !collecting) { + if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { + collecting_letter = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + collecting_numeric = true; + collecting = true; + } + else if ( + ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || + (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) + ) { + collecting_special = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { + collecting_whitespace_lookahead = true; + collecting = true; + } + else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { + split_condition = true; + } + } + else if (!split_condition && collecting) { + if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) { + split_condition = true; + } + else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_DIGIT) { + split_condition = true; + } + else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_DIGIT || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { + split_condition = true; + } + else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + split_condition = true; + } + } + + if (utf_char_next == "") { + split_condition = true; // final + token += utf_char; + } + + if (split_condition) { + if (token.size()) { + bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size()); + } + token = utf_char; + collecting = false; + collecting_letter = false; + collecting_numeric = false; + collecting_special = false; + collecting_whitespace_lookahead = false; + } + else { + token += utf_char; + } + } + + start += offset; + } + + return bpe_offsets; +} + +// use std::wregex to split the text +static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { + std::wregex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); + std::wcregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::wcmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +// use std::regex to split the text +static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::regex expr(regex_expr); + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + size_t start = 0; + for (auto offset : offsets) { + std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); + std::cregex_iterator end; + + int64_t start_idx = 0; + while (it != end) { + std::cmatch match = *it; + if (match.position() > start_idx) { + bpe_offsets.emplace_back(match.position() - start_idx); + } + bpe_offsets.emplace_back(match.length()); + start_idx = match.position() + match.length(); + ++it; + } + + if (start_idx < (int64_t) offset) { + bpe_offsets.emplace_back(offset - start_idx); + } + start += offset; + } + + return bpe_offsets; +} + +static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { + std::vector bpe_offsets; + + (void)(text); + (void)(regex_expr); + (void)(offsets); + // TODO: this implementation is actually wrong, uncomment and run: + // make -j && ./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf + //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") { + // bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets); + //} + + return bpe_offsets; +} + // // interface // std::string unicode_cpt_to_utf8(uint32_t cp) { std::string result; + if (/* 0x00 <= cp && */ cp <= 0x7f) { result.push_back(cp); + return result; } - else if (0x80 <= cp && cp <= 0x7ff) { + if (0x80 <= cp && cp <= 0x7ff) { result.push_back(0xc0 | ((cp >> 6) & 0x1f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x800 <= cp && cp <= 0xffff) { + if (0x800 <= cp && cp <= 0xffff) { result.push_back(0xe0 | ((cp >> 12) & 0x0f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else if (0x10000 <= cp && cp <= 0x10ffff) { + if (0x10000 <= cp && cp <= 0x10ffff) { result.push_back(0xf0 | ((cp >> 18) & 0x07)); result.push_back(0x80 | ((cp >> 12) & 0x3f)); result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | (cp & 0x3f)); + return result; } - else { - throw std::invalid_argument("invalid codepoint"); - } - return result; + + throw std::invalid_argument("invalid codepoint"); } std::vector unicode_cpts_normalize_nfd(const std::vector & cpts) { @@ -275,3 +520,167 @@ char32_t unicode_tolower(char32_t cp) { auto it = unicode_map_lowercase.find(cp); return it == unicode_map_lowercase.end() ? cp : it->second; } + +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { + // unicode categories + static const std::map k_ucat_enum = { + { "\\p{N}", CODEPOINT_TYPE_DIGIT }, + { "\\p{L}", CODEPOINT_TYPE_LETTER }, + { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION }, + }; + + static const std::map k_ucat_cpt = { + { CODEPOINT_TYPE_DIGIT, 0xD1 }, + { CODEPOINT_TYPE_LETTER, 0xD2 }, + { CODEPOINT_TYPE_PUNCTUATION, 0xD3 }, + }; + + static const std::map k_ucat_map = { + { CODEPOINT_TYPE_DIGIT, "\x30-\x39" }, // 0-9 + { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z + { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + }; + + // compute collapsed codepoints only if needed by at least one regex + bool need_collapse = false; + for (auto & regex_expr : regex_exprs) { + // search for unicode categories + for (const auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + need_collapse = true; + break; + } + } + } + + const auto cpts = unicode_cpts_from_utf8(text); + + // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte + // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935 + std::string text_collapsed; + if (need_collapse) { + // collapse all unicode categories + text_collapsed.resize(cpts.size()); + + for (size_t i = 0; i < cpts.size(); ++i) { + // keep single-byte codepoints as is + if (cpts[i] < 128) { + text_collapsed[i] = cpts[i]; + continue; + } + + const int cpt_type = unicode_cpt_type(cpts[i]); + + if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) { + text_collapsed[i] = k_ucat_cpt.at(cpt_type); + } else { + text_collapsed[i] = (char) 0xD0; // fallback + } + } + } + + std::vector bpe_offsets = { cpts.size() }; + + for (auto & regex_expr : regex_exprs) { + // first, see if we have an efficient custom regex implementation + auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); + + if (!tmp.empty()) { + bpe_offsets = std::move(tmp); + continue; + } + + // fallback to general-purpose std::regex / std::wregex + try { + // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category + // with the corresponding collapsed representation + bool use_collapsed = false; + for (auto & ucat : k_ucat_enum) { + if (std::string::npos != regex_expr.find(ucat.first)) { + use_collapsed = true; + break; + } + } + + if (use_collapsed) { + // sanity-check that the original regex does not contain any non-ASCII characters + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); + for (size_t i = 0; i < cpts_regex.size(); ++i) { + if (cpts_regex[i] >= 128) { + throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); + } + } + + // generate a collapsed representation of the regex + std::string regex_expr_collapsed; + + // track if we are inside [], because nested [] are not allowed + bool inside = false; + for (size_t i = 0; i < regex_expr.size(); ++i) { + if (regex_expr[i] == '[' && (i == 0 || regex_expr[i - 1] != '\\')) { + regex_expr_collapsed += '['; + inside = true; + continue; + } + + if (inside && regex_expr[i] == ']' && regex_expr[i - 1] != '\\') { + regex_expr_collapsed += ']'; + inside = false; + continue; + } + + if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + regex_expr[i + 1] == 'p' && + regex_expr[i + 2] == '{' && + regex_expr[i + 4] == '}') { + const std::string pat = regex_expr.substr(i, 5); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i += 4; + continue; + } + } + + regex_expr_collapsed += regex_expr[i]; + } + + //printf("text_collapsed: %s\n", text_collapsed.c_str()); + //printf("regex_expr_collapsed: %s\n", regex_expr_collapsed.c_str()); + bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); + } else { + // no unicode category used, we can use std::wregex directly + const std::wstring wtext = unicode_wstring_from_utf8(text); + const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + + //printf("text: %s\n", text.c_str()); + //printf("regex_expr: %s\n", regex_expr.c_str()); + bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets); + } + } catch (std::regex_error & e) { + fprintf(stderr, "Failed to process regex: '%s'\n", regex_expr.c_str()); + fprintf(stderr, "Regex error: %s\n", e.what()); + throw std::runtime_error("Failed to process regex"); + } + } + + std::vector bpe_words; + bpe_words.reserve(bpe_offsets.size()); // reserve memory for the approximate size + + size_t start = 0; + for (size_t & offset : bpe_offsets) { + bpe_words.emplace_back(); + for (size_t i = start; i < start + offset; ++i) { + bpe_words.back() += unicode_cpt_to_utf8(cpts[i]); + } + start += offset; + } + + return unicode_byte_encoding_process(bpe_words); +} diff --git a/unicode.h b/unicode.h index 6a0be393a..ce2bcef5a 100644 --- a/unicode.h +++ b/unicode.h @@ -24,5 +24,6 @@ int unicode_cpt_type(const std::string & utf8); std::string unicode_byte_to_utf8(uint8_t byte); uint8_t unicode_utf8_to_byte(const std::string & utf8); -// simple tolower that only implements one-to-one mapping, not one-to-many char32_t unicode_tolower(char32_t cp); + +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs);