Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	common/common.cpp
#	examples/batched-bench/batched-bench.cpp
#	examples/batched/batched.cpp
#	examples/export-lora/export-lora.cpp
#	examples/gritlm/gritlm.cpp
#	examples/parallel/parallel.cpp
#	examples/passkey/passkey.cpp
#	examples/speculative-simple/speculative-simple.cpp
#	examples/speculative/speculative.cpp
#	ggml/src/ggml-cann/CMakeLists.txt
#	ggml/src/ggml-cann/acl_tensor.cpp
#	ggml/src/ggml-cann/acl_tensor.h
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/aclnn_ops.h
#	ggml/src/ggml-vulkan/CMakeLists.txt
#	tests/test-arg-parser.cpp
#	tests/test-backend-ops.cpp
Concedo committed 2025-04-03 18:57:49 +08:00 (commit 103d60ed2c)
43 changed files with 1509 additions and 1129 deletions

common/arg.cpp
@@ -1,10 +1,21 @@
#include "gguf.h" // for reading GGUF splits
#include "arg.h"
#include "common.h"
#include "log.h"
#include "sampling.h"
#include "chat.h"
#include "build-info.h"
// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include <algorithm>
#include <climits>
#include <cstdarg>
@@ -15,6 +26,14 @@
#include <thread>
#include <vector>
//#define LLAMA_USE_CURL
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <future>
#endif
#include "json-schema-to-grammar.h"
using json = nlohmann::ordered_json;
@@ -126,47 +145,549 @@ std::string common_arg::to_string() {
return ss.str();
}
//
// downloader
//
struct common_hf_file_res {
std::string repo; // repo name with ":tag" removed
std::string ggufFile;
std::string mmprojFile;
};
#ifdef LLAMA_USE_CURL
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
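// Note on the schedule this produces: with the defaults defined above
// (CURL_RETRY_DELAY_SECONDS = 2, CURL_MAX_RETRY = 3) the computed delay is
// pow(2, attempt - 1) * 1000 ms, i.e. roughly a 1 s wait after the first failed
// attempt and 2 s after the second, before the third and final attempt
// (a last 4 s sleep also happens after the final failure, just before giving up).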
// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) {
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata.at("url").is_string()) {
auto previous_url = metadata.at("url").get<std::string>();
if (previous_url != url) {
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
}
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
force_download = true;
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
return true;
}
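// For reference, the "<path>.json" companion file written above contains the
// "url", "etag" and "lastModified" keys, pretty-printed with dump(4); the values
// shown here are illustrative only:
//
// {
//     "etag": "\"5f2b3c...\"",
//     "lastModified": "Tue, 01 Apr 2025 00:00:00 GMT",
//     "url": "https://huggingface.co/<user>/<model>/resolve/main/model.gguf"
// }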
// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token);
}, item));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return false;
}
}
return true;
}
static bool common_download_model(
const common_params_model & model,
const std::string & bearer_token) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
}
if (!common_download_file_single(model.url, model.path, bearer_token)) {
return false;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
return false;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
return false;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
return false;
}
}
std::vector<std::pair<std::string, std::string>> urls;
for (int idx = 1; idx < n_split; idx++) {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
if (std::string(split_path) == model.path) {
continue; // skip the already downloaded file
}
urls.push_back({split_url, split_path});
}
// Download in parallel
common_download_file_multiple(urls, bearer_token);
}
return true;
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
std::string ggufFile = "";
std::string mmprojFile = "";
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
// extract ggufFile.rfilename in json, using regex
{
std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\"");
std::smatch match;
if (std::regex_search(res_str, match, pattern)) {
ggufFile = match[1].str();
}
}
// extract mmprojFile.rfilename in json, using regex
{
std::regex pattern("\"mmprojFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\"");
std::smatch match;
if (std::regex_search(res_str, match, pattern)) {
mmprojFile = match[1].str();
}
}
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (ggufFile.empty()) {
throw std::runtime_error("error: model does not have ggufFile");
}
return { hf_repo, ggufFile, mmprojFile };
}
#else
static bool common_download_file_single(const std::string &, const std::string &, const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from internet\n");
return false;
}
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> &, const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return false;
}
static bool common_download_model(
const common_params_model &,
const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return false;
}
static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) {
LOG_ERR("error: built without CURL, cannot download model from the internet\n");
return {};
}
#endif // LLAMA_USE_CURL
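// Illustrative sketch only (not part of this commit): the smallest way to exercise
// the downloader above when built with LLAMA_USE_CURL. The URL follows the
// repo/resolve/main pattern used elsewhere in this diff; the cache file name is a
// placeholder. common_params_model, fs_get_cache_file and common_download_model
// are the types/functions defined above and in common.h.
static bool example_download_tinyllama() {
    common_params_model model;
    model.url  = "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf";
    model.path = fs_get_cache_file("ggml-org_models_ggml-model-f16.gguf");
    // downloads the main GGUF (with ETag/Last-Modified caching), then fetches any
    // additional split files in parallel via common_download_file_multiple()
    return common_download_model(model, /*bearer_token=*/"");
}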
//
// utils
//
-static void common_params_handle_model_default(
-std::string & model,
-const std::string & model_url,
-std::string & hf_repo,
-std::string & hf_file,
-const std::string & hf_token,
-const std::string & model_default) {
-if (!hf_repo.empty()) {
+static void common_params_handle_model(
+struct common_params_model & model,
+const std::string & bearer_token,
+const std::string & model_path_default,
+bool is_mmproj = false) { // TODO: move is_mmproj to an enum when we have more files?
+// handle pre-fill default model path and url based on hf_repo and hf_file
+{
+if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
-if (hf_file.empty()) {
-if (model.empty()) {
-auto auto_detected = common_get_hf_file(hf_repo, hf_token);
-if (auto_detected.first.empty() || auto_detected.second.empty()) {
+if (model.hf_file.empty()) {
+if (model.path.empty()) {
+auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token);
+if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
exit(1); // built without CURL, error message already printed
}
-hf_repo = auto_detected.first;
-hf_file = auto_detected.second;
+model.hf_repo = auto_detected.repo;
+model.hf_file = is_mmproj ? auto_detected.mmprojFile : auto_detected.ggufFile;
} else {
-hf_file = model;
+model.hf_file = model.path;
}
}
+// TODO: allow custom host
+model.url = "https://huggingface.co/" + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
-if (model.empty()) {
+if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
-std::string filename = hf_repo + "_" + hf_file;
+std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
-model = fs_get_cache_file(filename);
+model.path = fs_get_cache_file(filename);
}
-} else if (!model_url.empty()) {
-if (model.empty()) {
-auto f = string_split<std::string>(model_url, '#').front();
+} else if (!model.url.empty()) {
+if (model.path.empty()) {
+auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
-model = fs_get_cache_file(string_split<std::string>(f, '/').back());
+model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
+}
+} else if (model.path.empty()) {
+model.path = model_path_default;
}
-} else if (model.empty()) {
-model = model_default;
}
+// then, download it if needed
+if (!model.url.empty()) {
+bool ok = common_download_model(model, bearer_token);
+if (!ok) {
+LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
+exit(1);
+}
+}
}
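// Worked example of the default cache path logic above (the repo/file pair is the
// gte-small default that appears later in this diff; the exact cache directory
// depends on fs_get_cache_file): with -hf ggml-org/gte-small-Q8_0-GGUF resolving to
// gte-small-q8_0.gguf and no explicit --model, the code builds
//
//     std::string filename = "ggml-org/gte-small-Q8_0-GGUF" + std::string("_") + "gte-small-q8_0.gguf";
//     string_replace_all(filename, "/", "_");         // "ggml-org_gte-small-Q8_0-GGUF_gte-small-q8_0.gguf"
//     std::string path = fs_get_cache_file(filename); // placed in the common cache directory
//
// and the download URL becomes https://huggingface.co/ggml-org/gte-small-Q8_0-GGUF/resolve/main/gte-small-q8_0.gguf.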
@@ -301,10 +822,16 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
-// TODO: refactor model params in a common struct
-common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
-common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
-common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");
+common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
+common_params_handle_model(params.speculative.model, params.hf_token, "");
+common_params_handle_model(params.vocoder.model, params.hf_token, "");
+// allow --mmproj to be set from -hf
+// assuming that mmproj is always in the same repo as text model
+if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) {
+params.mmproj.hf_repo = params.model.hf_repo;
+}
+common_params_handle_model(params.mmproj, params.hf_token, "", true);
if (params.escape) {
string_process_escapes(params.prompt);
@@ -323,6 +850,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.kv_overrides.back().key[0] = 0;
}
if (!params.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
@@ -1562,7 +2093,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--mmproj"}, "FILE",
"path to a multimodal projector file for LLaVA. see examples/llava/README.md",
[](common_params & params, const std::string & value) {
-params.mmproj = value;
+params.mmproj.path = value;
}
).set_examples({LLAMA_EXAMPLE_LLAVA}));
+add_opt(common_arg(
+{"--mmproj-url"}, "URL",
+"URL to a multimodal projector file for LLaVA. see examples/llava/README.md",
+[](common_params & params, const std::string & value) {
+params.mmproj.url = value;
+}
+).set_examples({LLAMA_EXAMPLE_LLAVA}));
add_opt(common_arg(
@@ -1648,6 +2186,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
exit(0);
}
));
add_opt(common_arg(
{"--override-tensor", "-ot"}, "<tensor name pattern>=<buffer type>,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
/* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
if (buft_list.empty()) {
// enumerate all the devices and add their buffer types to the list
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
auto * buft = ggml_backend_dev_buffer_type(dev);
if (buft) {
buft_list[ggml_backend_buft_name(buft)] = buft;
}
}
}
for (const auto & override : string_split<std::string>(value, ',')) {
std::string::size_type pos = override.find('=');
if (pos == std::string::npos) {
throw std::invalid_argument("invalid value");
}
std::string tensor_name = override.substr(0, pos);
std::string buffer_type = override.substr(pos + 1);
if (buft_list.find(buffer_type) == buft_list.end()) {
printf("Available buffer types:\n");
for (const auto & it : buft_list) {
printf(" %s\n", ggml_backend_buft_name(it.second));
}
throw std::invalid_argument("unknown buffer type");
}
// FIXME: this leaks memory
params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
}
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@@ -1791,14 +2364,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
),
[](common_params & params, const std::string & value) {
-params.model = value;
+params.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
add_opt(common_arg(
{"-mu", "--model-url"}, "MODEL_URL",
"model download url (default: unused)",
[](common_params & params, const std::string & value) {
-params.model_url = value;
+params.model.url = value;
}
).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(common_arg(
@@ -1807,35 +2380,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"example: unsloth/phi-4-GGUF:q4_k_m\n"
"(default: unused)",
[](common_params & params, const std::string & value) {
-params.hf_repo = value;
+params.model.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO"));
add_opt(common_arg(
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
-params.speculative.hf_repo = value;
+params.speculative.model.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
[](common_params & params, const std::string & value) {
-params.hf_file = value;
+params.model.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE"));
add_opt(common_arg(
{"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
"Hugging Face model repository for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
-params.vocoder.hf_repo = value;
+params.vocoder.model.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO_V"));
add_opt(common_arg(
{"-hffv", "--hf-file-v"}, "FILE",
"Hugging Face model file for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
-params.vocoder.hf_file = value;
+params.vocoder.model.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE_V"));
add_opt(common_arg(
@@ -2455,7 +3028,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
-params.speculative.model = value;
+params.speculative.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
@@ -2463,7 +3036,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-mv", "--model-vocoder"}, "FNAME",
"vocoder model for audio generation (default: unused)",
[](common_params & params, const std::string & value) {
-params.vocoder.model = value;
+params.vocoder.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
@@ -2486,10 +3059,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--tts-oute-default"},
string_format("use default OuteTTS models (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF";
-params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
-params.vocoder.hf_repo = "ggml-org/WavTokenizer";
-params.vocoder.hf_file = "WavTokenizer-Large-75-F16.gguf";
+params.model.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF";
+params.model.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
+params.vocoder.model.hf_repo = "ggml-org/WavTokenizer";
+params.vocoder.model.hf_file = "WavTokenizer-Large-75-F16.gguf";
}
).set_examples({LLAMA_EXAMPLE_TTS}));
@@ -2497,8 +3070,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--embd-bge-small-en-default"},
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
-params.hf_file = "bge-small-en-v1.5-q8_0.gguf";
+params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
+params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
@@ -2511,8 +3084,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--embd-e5-small-en-default"},
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
-params.hf_file = "e5-small-v2-q8_0.gguf";
+params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
+params.model.hf_file = "e5-small-v2-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
@@ -2525,8 +3098,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--embd-gte-small-default"},
string_format("use default gte-small model (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
-params.hf_file = "gte-small-q8_0.gguf";
+params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
+params.model.hf_file = "gte-small-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
@@ -2539,8 +3112,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--fim-qwen-1.5b-default"},
string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
-params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
+params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
+params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
@@ -2555,8 +3128,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--fim-qwen-3b-default"},
string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
-params.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
+params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
+params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
@@ -2571,8 +3144,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--fim-qwen-7b-default"},
string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
-params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
+params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
+params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
@@ -2587,10 +3160,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--fim-qwen-7b-spec"},
string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
-params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
-params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
+params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
+params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;
@@ -2606,10 +3179,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--fim-qwen-14b-spec"},
string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"),
[](common_params & params) {
-params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
-params.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
-params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
-params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
+params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
+params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
+params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
+params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012;
params.n_gpu_layers = 99;

common/common.cpp

@@ -55,47 +55,11 @@
#include <sys/stat.h>
#include <unistd.h>
#endif
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <future>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#if defined(LLAMA_USE_CURL)
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
//
// CPU utils
//
@@ -904,22 +868,14 @@ std::string fs_get_cache_file(const std::string & filename) {
//
// Model utils
//
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
-llama_model * model = nullptr;
-if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
-} else if (!params.model_url.empty()) {
-model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
-} else {
-model = llama_model_load_from_file(params.model.c_str(), mparams);
-}
+llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
-LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
+LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
return iparams;
}
@@ -954,7 +910,7 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
-LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
+LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
}
@@ -1093,15 +1049,18 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@@ -1109,6 +1068,13 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.kv_overrides = params.kv_overrides.data();
}
if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
return mparams;
}
@@ -1168,451 +1134,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
return tpp;
}
#ifdef LLAMA_USE_CURL
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata.at("url").is_string()) {
auto previous_url = metadata.at("url").get<std::string>();
if (previous_url != url) {
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
}
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
force_download = true;
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
return true;
}
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (model_url.empty()) {
LOG_ERR("%s: invalid model_url\n", __func__);
return NULL;
}
if (!common_download_file(model_url, local_path, hf_token)) {
return NULL;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
return NULL;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
return NULL;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
return NULL;
}
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
return common_download_file(split_url, split_path, hf_token);
}, idx));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return NULL;
}
}
}
return llama_model_load_from_file(local_path.c_str(), params);
}
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += remote_path;
return common_load_model_from_url(model_url, local_path, hf_token, params);
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
json model_info;
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
model_info = json::parse(res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (!model_info.contains("ggufFile")) {
throw std::runtime_error("error: model does not have ggufFile");
}
json & gguf_file = model_info.at("ggufFile");
if (!gguf_file.contains("rfilename")) {
throw std::runtime_error("error: ggufFile does not have rfilename");
}
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
#else
struct llama_model * common_load_model_from_url(
const std::string & /*model_url*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}
struct llama_model * common_load_model_from_hf(
const std::string & /*repo*/,
const std::string & /*remote_path*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return std::make_pair("", "");
}
#endif // LLAMA_USE_CURL
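A short sketch tying the two entry points together, using the bartowski example from the doc comment; the local cache path is a placeholder, and tag resolution only works in builds with LLAMA_USE_CURL, otherwise the stubs above return empty values.

```cpp
// sketch: resolve "<user>/<model>:<quant>" to a concrete file, then fetch it
auto [repo, file] = common_get_hf_file("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", /*hf_token=*/"");
if (!repo.empty() && !file.empty()) {
    llama_model * model = common_load_model_from_hf(
        repo, file, "/tmp/" + file /* placeholder cache path */, /*hf_token=*/"", llama_model_default_params());
    if (model == nullptr) {
        LOG_ERR("%s: failed to load %s from %s\n", __func__, file.c_str(), repo.c_str());
    }
}
```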
// //
// Batch utils // Batch utils
// //
@ -2036,26 +1557,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result; return result;
} }
template <>
json common_grammar_trigger::to_json() const {
json out {
{"type", (int) type},
{"value", value},
};
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) token;
}
return out;
}
template <>
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
common_grammar_trigger out;
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
out.value = in.at("value").get<std::string>();
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out.token = (llama_token) in.at("token").get<int>();
}
return out;
}

View file

@ -117,10 +117,6 @@ struct common_grammar_trigger {
common_grammar_trigger_type type; common_grammar_trigger_type type;
std::string value; std::string value;
llama_token token = LLAMA_TOKEN_NULL; llama_token token = LLAMA_TOKEN_NULL;
// T can only be nlohmann::ordered_json
template <class T> T to_json() const;
template <class T> static common_grammar_trigger from_json(const T & in);
}; };
// sampling parameters // sampling parameters
@ -180,6 +176,13 @@ struct common_params_sampling {
std::string print() const; std::string print() const;
}; };
struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
};
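To illustrate the consolidation above (a sketch, not code from this change): the repo and file values are the tinyllama example used elsewhere in this patch, and the comments map each field back to the flat members removed below.

```cpp
// sketch: populating the consolidated descriptor (values are examples/placeholders)
common_params params;
params.model.hf_repo = "ggml-org/models";                    // previously params.hf_repo
params.model.hf_file = "tinyllama-1.1b/ggml-model-f16.gguf"; // previously params.hf_file
// or a plain local path / download URL instead:
params.model.path    = "models/ggml-model-f16.gguf";         // previously params.model
params.model.url     = "";                                   // previously params.model_url
```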
struct common_params_speculative { struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
@ -193,19 +196,11 @@ struct common_params_speculative {
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
std::string hf_repo = ""; // HF repo // NOLINT struct common_params_model model;
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
}; };
struct common_params_vocoder { struct common_params_vocoder {
std::string hf_repo = ""; // HF repo // NOLINT struct common_params_model model;
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string speaker_file = ""; // speaker file path // NOLINT std::string speaker_file = ""; // speaker file path // NOLINT
@ -263,12 +258,10 @@ struct common_params {
struct common_params_speculative speculative; struct common_params_speculative speculative;
struct common_params_vocoder vocoder; struct common_params_vocoder vocoder;
std::string model = ""; // model path // NOLINT struct common_params_model model;
std::string model_alias = ""; // model alias // NOLINT std::string model_alias = ""; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT std::string hf_token = ""; // HF token // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string prompt = ""; // NOLINT std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT
@ -282,6 +275,7 @@ struct common_params {
std::vector<std::string> in_files; // all input files std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
@ -343,7 +337,7 @@ struct common_params {
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see examples/llava) // multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector // NOLINT struct common_params_model mmproj;
std::vector<std::string> image; // path to image file(s) std::vector<std::string> image; // path to image file(s)
// embedding // embedding
@ -542,23 +536,6 @@ struct llama_model_params common_model_params_to_llama ( common_params
struct llama_context_params common_context_params_to_llama(const common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
// clear LoRA adapters from context, then apply new list of adapters // clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora); void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

View file

@ -5146,10 +5146,7 @@ class BailingMoeModel(Model):
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
if hparams.get("head_dim"): rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]
rope_dim = hparams["head_dim"]
else:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim) self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
@ -5175,7 +5172,7 @@ class BailingMoeModel(Model):
n_head = self.hparams["num_attention_heads"] n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads") n_kv_head = self.hparams.get("num_key_value_heads")
n_embd = self.hparams["hidden_size"] n_embd = self.hparams["hidden_size"]
head_dim = self.hparams.get("head_dim", n_embd // n_head) head_dim = self.hparams.get("head_dim") or n_embd // n_head
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)

View file

@ -4,6 +4,26 @@
> >
> This is very experimental, only used for demo purpose. > This is very experimental, only used for demo purpose.
## Quick start
You can use a pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)'s Hugging Face account
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
# alternatively, install via brew (macOS)
brew install llama.cpp
# run it
llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
# note: 1B model does not support vision
```
## How to get mmproj.gguf? ## How to get mmproj.gguf?
```bash ```bash

View file

@ -78,7 +78,7 @@ struct gemma3_context {
} }
void init_clip_model(common_params & params) { void init_clip_model(common_params & params) {
const char * clip_path = params.mmproj.c_str(); const char * clip_path = params.mmproj.path.c_str();
ctx_clip = clip_model_load(clip_path, params.verbosity > 1); ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
} }
@ -232,13 +232,13 @@ int main(int argc, char ** argv) {
common_init(); common_init();
if (params.mmproj.empty()) { if (params.mmproj.path.empty()) {
show_additional_info(argc, argv); show_additional_info(argc, argv);
return 1; return 1;
} }
gemma3_context ctx(params); gemma3_context ctx(params);
printf("%s: %s\n", __func__, params.model.c_str()); printf("%s: %s\n", __func__, params.model.path.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty(); bool is_single_turn = !params.prompt.empty() && !params.image.empty();

View file

@ -225,7 +225,7 @@ static struct llama_model * llava_init(common_params * params) {
llama_model_params model_params = common_model_params_to_llama(*params); llama_model_params model_params = common_model_params_to_llama(*params);
llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n" , __func__); LOG_ERR("%s: unable to load model\n" , __func__);
return NULL; return NULL;
@ -234,7 +234,7 @@ static struct llama_model * llava_init(common_params * params) {
} }
static struct llava_context * llava_init_context(common_params * params, llama_model * model) { static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
const char * clip_path = params->mmproj.c_str(); const char * clip_path = params->mmproj.path.c_str();
auto prompt = params->prompt; auto prompt = params->prompt;
if (prompt.empty()) { if (prompt.empty()) {
@ -283,7 +283,7 @@ int main(int argc, char ** argv) {
common_init(); common_init();
if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
print_usage(argc, argv); print_usage(argc, argv);
return 1; return 1;
} }

View file

@ -31,7 +31,7 @@ static struct llama_model * llava_init(common_params * params) {
llama_model_params model_params = common_model_params_to_llama(*params); llama_model_params model_params = common_model_params_to_llama(*params);
llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n" , __func__); LOG_ERR("%s: unable to load model\n" , __func__);
return NULL; return NULL;
@ -80,7 +80,7 @@ static void llava_free(struct llava_context * ctx_llava) {
} }
static struct clip_ctx * clip_init_context(common_params * params) { static struct clip_ctx * clip_init_context(common_params * params) {
const char * clip_path = params->mmproj.c_str(); const char * clip_path = params->mmproj.path.c_str();
auto prompt = params->prompt; auto prompt = params->prompt;
if (prompt.empty()) { if (prompt.empty()) {
@ -290,7 +290,7 @@ int main(int argc, char ** argv) {
common_init(); common_init();
if (params.mmproj.empty() || (params.image.empty())) { if (params.mmproj.path.empty() || (params.image.empty())) {
show_additional_info(argc, argv); show_additional_info(argc, argv);
return 1; return 1;
} }

View file

@ -314,7 +314,7 @@ static struct llama_model * llava_init(common_params * params) {
llama_model_params model_params = common_model_params_to_llama(*params); llama_model_params model_params = common_model_params_to_llama(*params);
llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n" , __func__); LOG_ERR("%s: unable to load model\n" , __func__);
return NULL; return NULL;
@ -323,7 +323,7 @@ static struct llama_model * llava_init(common_params * params) {
} }
static struct llava_context * llava_init_context(common_params * params, llama_model * model) { static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
const char * clip_path = params->mmproj.c_str(); const char * clip_path = params->mmproj.path.c_str();
auto prompt = params->prompt; auto prompt = params->prompt;
if (prompt.empty()) { if (prompt.empty()) {
@ -524,7 +524,7 @@ int main(int argc, char ** argv) {
common_init(); common_init();
if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
print_usage(argc, argv); print_usage(argc, argv);
return 1; return 1;
} }

View file

@ -133,7 +133,8 @@ struct slot_params {
auto grammar_triggers = json::array(); auto grammar_triggers = json::array();
for (const auto & trigger : sampling.grammar_triggers) { for (const auto & trigger : sampling.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json<json>()); server_grammar_trigger ct(std::move(trigger));
grammar_triggers.push_back(ct.to_json());
} }
return json { return json {
@ -372,9 +373,9 @@ struct server_task {
const auto grammar_triggers = data.find("grammar_triggers"); const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) { if (grammar_triggers != data.end()) {
for (const auto & t : *grammar_triggers) { for (const auto & t : *grammar_triggers) {
auto ct = common_grammar_trigger::from_json(t); server_grammar_trigger ct(t);
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
const auto & word = ct.value; const auto & word = ct.value.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) { if (ids.size() == 1) {
auto token = ids[0]; auto token = ids[0];
@ -392,7 +393,7 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
} }
} else { } else {
params.sampling.grammar_triggers.push_back(ct); params.sampling.grammar_triggers.push_back(std::move(ct.value));
} }
} }
} }
@ -1876,7 +1877,7 @@ struct server_context {
} }
bool load_model(const common_params & params) { bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.c_str()); SRV_INF("loading model '%s'\n", params.model.path.c_str());
params_base = params; params_base = params;
@ -1886,7 +1887,7 @@ struct server_context {
ctx = llama_init.context.get(); ctx = llama_init.context.get();
if (model == nullptr) { if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
return false; return false;
} }
@ -1897,16 +1898,13 @@ struct server_context {
add_bos_token = llama_vocab_get_add_bos(vocab); add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
auto params_dft = params_base; auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices; params_dft.devices = params_base.speculative.devices;
params_dft.hf_file = params_base.speculative.hf_file;
params_dft.hf_repo = params_base.speculative.hf_repo;
params_dft.model = params_base.speculative.model; params_dft.model = params_base.speculative.model;
params_dft.model_url = params_base.speculative.model_url;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1; params_dft.n_parallel = 1;
@ -1920,12 +1918,12 @@ struct server_context {
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft.model.get();
if (model_dft == nullptr) { if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
return false; return false;
} }
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
return false; return false;
} }
@ -3865,7 +3863,7 @@ int main(int argc, char ** argv) {
json data = { json data = {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel }, { "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model }, { "model_path", ctx_server.params_base.model.path },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@ -4131,7 +4129,7 @@ int main(int argc, char ** argv) {
{"object", "list"}, {"object", "list"},
{"data", { {"data", {
{ {
{"id", params.model_alias.empty() ? params.model : params.model_alias}, {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
{"object", "model"}, {"object", "model"},
{"created", std::time(0)}, {"created", std::time(0)},
{"owned_by", "llamacpp"}, {"owned_by", "llamacpp"},

View file

@ -58,6 +58,32 @@ static T json_value(const json & body, const std::string & key, const T & defaul
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
common_grammar_trigger value;
server_grammar_trigger() = default;
server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
server_grammar_trigger(const json & in) {
value.type = (common_grammar_trigger_type) in.at("type").get<int>();
value.value = in.at("value").get<std::string>();
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
value.token = (llama_token) in.at("token").get<int>();
}
}
json to_json() const {
json out {
{"type", (int) value.type},
{"value", value.value},
};
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) value.token;
}
return out;
}
};
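A minimal round-trip sketch of the wrapper above, assuming the `json` alias (nlohmann::ordered_json) used by the rest of the server code; the trigger word is a placeholder.

```cpp
// sketch: serialize a word trigger and parse it back (placeholder trigger word)
common_grammar_trigger trig;
trig.type  = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
trig.value = "<tool_call>";

json j = server_grammar_trigger(trig).to_json(); // {"type": <int>, "value": "<tool_call>"}
server_grammar_trigger parsed(j);                // deserializing constructor
GGML_ASSERT(parsed.value.type  == COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
GGML_ASSERT(parsed.value.value == trig.value);
```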
// //
// tokenizer and input processing utils // tokenizer and input processing utils
// //
@ -627,7 +653,8 @@ static json oaicompat_completion_params_parse(
llama_params["grammar_lazy"] = chat_params.grammar_lazy; llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array(); auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) { for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json<json>()); server_grammar_trigger ct(trigger);
grammar_triggers.push_back(ct.to_json());
} }
llama_params["grammar_triggers"] = grammar_triggers; llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens; llama_params["preserved_tokens"] = chat_params.preserved_tokens;

View file

@ -577,12 +577,7 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model_ttc); const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
// TODO: refactor in a common struct
params.model = params.vocoder.model; params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url;
params.hf_repo = params.vocoder.hf_repo;
params.hf_file = params.vocoder.hf_file;
params.embedding = true; params.embedding = true;
common_init_result llama_init_cts = common_init_from_params(params); common_init_result llama_init_cts = common_init_from_params(params);

View file

@ -1420,6 +1420,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
ggml_cann_argsort(ctx, dst); ggml_cann_argsort(ctx, dst);
break; break;
case GGML_OP_ARGMAX:
ggml_cann_argmax(ctx, dst);
break;
case GGML_OP_COS:
ggml_cann_cos(ctx, dst);
break;
case GGML_OP_SIN:
ggml_cann_sin(ctx, dst);
break;
default: default:
return false; return false;
} }
@ -1458,11 +1467,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeDevice()); ACL_CHECK(aclrtSynchronizeDevice());
ACL_CHECK(aclrtResetDevice(cann_ctx->device)); ACL_CHECK(aclrtResetDevice(cann_ctx->device));
// finalize when last backend freed.
if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
ACL_CHECK(aclFinalize());
}
delete cann_ctx; delete cann_ctx;
delete backend; delete backend;
} }
@ -1688,11 +1692,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
} }
case GGML_OP_MUL_MAT: { case GGML_OP_MUL_MAT: {
switch (op->src[0]->type) { switch (op->src[0]->type) {
case GGML_TYPE_Q8_0:
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
return true; return true;
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
                    // quantized types are only supported for contiguous tensors.
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]);
default: default:
return false; return false;
} }
@ -1704,7 +1711,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
switch (op->src[0]->type) { switch (op->src[0]->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return true; return true;
default: default:
@ -1712,16 +1718,21 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
} }
} break; } break;
case GGML_OP_CPY: { case GGML_OP_CPY: {
switch (op->type) { ggml_tensor *src = op->src[0];
case GGML_TYPE_F32: if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
case GGML_TYPE_F16: (src->type != GGML_TYPE_F32 &&
case GGML_TYPE_Q8_0: src->type != GGML_TYPE_F16)) {
case GGML_TYPE_Q4_0: // only support F32 and F16.
return true;
default:
return false; return false;
} }
if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
                // unsupported: dst is not contiguous.
return false;
} }
return true;
} break;
case GGML_OP_CONT: { case GGML_OP_CONT: {
// TODO: support GGML_TYPE_BF16 // TODO: support GGML_TYPE_BF16
switch (op->src[0]->type) { switch (op->src[0]->type) {
@ -1734,13 +1745,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
} }
case GGML_OP_ROPE: { case GGML_OP_ROPE: {
// TODO: with ops-test v == 1 // TODO: with ops-test v == 1
float * ext_factor = (float*)((int32_t*)op->op_params + 7); float ext_factor = 0.0f;
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0 // TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) { if (op->src[0]->ne[0] != op->op_params[1]) {
return false; return false;
} }
// TODO: ext_factor != 0 // TODO: ext_factor != 0
if (*ext_factor != 0) { if (ext_factor != 0) {
return false; return false;
} }
@ -1762,9 +1774,19 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
} }
return true; return true;
} }
case GGML_OP_POOL_2D: {
const int32_t * opts = (const int32_t *) op->op_params;
const int k0 = opts[1];
const int k1 = opts[2];
const int p0 = opts[5];
const int p1 = opts[6];
// value of paddingH should be at most half of kernelH
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_DUP:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_DUP:
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
@ -1781,7 +1803,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
case GGML_OP_ACC: case GGML_OP_ACC:
@ -1790,6 +1811,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_ARGMAX:
case GGML_OP_COS:
case GGML_OP_SIN:
return true; return true;
default: default:
return false; return false;

View file

@ -729,7 +729,13 @@ struct ggml_cuda_graph {
bool disable_due_to_failed_graph_capture = false; bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0; int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties; std::vector<ggml_graph_node_properties> ggml_graph_properties;
std::vector<char **> updated_kernel_arg; bool use_cpy_indirection = false;
std::vector<char *> cpy_dest_ptrs;
char ** dest_ptrs_d;
int dest_ptrs_size = 0;
    // Index to allow each cpy kernel to be aware of its position within the graph
// relative to other cpy nodes.
int graph_cpynode_index = -1;
#endif #endif
}; };

View file

@ -32,16 +32,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
} }
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) { const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets // then combine those indices with the corresponding byte offsets to get the total offsets
const int64_t i03 = i/(ne00 * ne01 * ne02); const int64_t i03 = i/(ne00 * ne01 * ne02);
@ -288,16 +290,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) { const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@ -314,16 +318,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) { const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
const int i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
@ -339,66 +345,84 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
// Copy destination pointers to the GPU so they are available when pointer indirection is in use
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
CUDA_CHECK(cudaStreamSynchronize(stream));
if (cuda_graph->dest_ptrs_d != nullptr) {
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
}
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
}
// copy destination pointers to GPU
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
cuda_graph->graph_cpynode_index = 0; // reset index
#endif
}
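For context, a hedged sketch of how a caller is expected to drive this helper when pointer indirection is enabled: the host side gathers the destination pointer of every copy node, uploads them once, and the per-node index is then consumed by the kernels above. The surrounding graph-update code is not part of this hunk, so the collection loop below (and the `ctx`/`cgraph` variables it assumes in scope) is an assumption, not the actual implementation.

```cpp
// sketch (assumed caller, not shown in this hunk): gather and upload dst pointers
// once per graph update when indirection is enabled
ctx.cuda_graph->cpy_dest_ptrs.clear();
for (int i = 0; i < cgraph->n_nodes; ++i) {
    ggml_tensor * node = cgraph->nodes[i];
    if (node->op == GGML_OP_CPY) {
        // ggml_cuda_cpy writes to src[1]->data, so that is the destination to record
        ctx.cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
    }
}
ctx.cuda_graph->use_cpy_indirection = true;
ggml_cuda_cpy_dest_ptrs_copy(ctx.cuda_graph.get(), ctx.cuda_graph->cpy_dest_ptrs.data(),
                             (int) ctx.cuda_graph->cpy_dest_ptrs.size(), ctx.stream());
```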
static void ggml_cpy_f16_f32_cuda( static void ggml_cpy_f16_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_f32_cuda( static void ggml_cpy_f32_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_f16_cuda( static void ggml_cpy_f32_f16_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_q8_0_cuda( static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0; const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_q8_0_f32_cuda( static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_q4_0_cuda( static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0; const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_q4_0_f32_cuda( static void ggml_cpy_q4_0_f32_cuda(
@ -407,22 +431,22 @@ static void ggml_cpy_q4_0_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_q4_1_cuda( static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1; const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_q4_1_f32_cuda( static void ggml_cpy_q4_1_f32_cuda(
@ -431,22 +455,22 @@ static void ggml_cpy_q4_1_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_q5_0_cuda( static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0; const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_q5_0_f32_cuda( static void ggml_cpy_q5_0_f32_cuda(
@ -455,22 +479,22 @@ static void ggml_cpy_q5_0_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_q5_1_cuda( static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1; const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_q5_1_f32_cuda( static void ggml_cpy_q5_1_f32_cuda(
@ -479,32 +503,32 @@ static void ggml_cpy_q5_1_f32_cuda(
const int nb00, const int nb01, const int nb02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = ne; const int num_blocks = ne;
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f32_iq4_nl_cuda( static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL; const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
static void ggml_cpy_f16_f16_cuda( static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
} }
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@ -541,46 +565,60 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data; char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data; char * src1_ddc = (char *) src1->data;
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
    if (ctx.cuda_graph->use_cpy_indirection) {
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
}
#endif
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else { } else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));
} }
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
}
#endif
} }
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
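The new dest_ptrs_d / graph_cpynode_index arguments threaded through every ggml_cpy_*_cuda launcher above implement destination-pointer indirection for CUDA graphs: when indirection is enabled, the copy kernel looks its output pointer up in a device-side table indexed by the copy node's position, so a captured graph stays valid even though the destination address changes from token to token. A minimal sketch of the idea, with hypothetical kernel and variable names rather than this commit's code:

#include <cuda_runtime.h>
#include <cstdio>

// When dest_ptrs is non-null the kernel ignores the baked-in destination and
// reads the live one from the device-side table instead.
__global__ void copy_f32_indirect(const float * src, float * fixed_dst,
                                  char ** dest_ptrs, int node_idx, int n) {
    float * dst = dest_ptrs ? (float *) dest_ptrs[node_idx] : fixed_dst;
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}

int main() {
    const int n = 1024;
    float * src = nullptr;
    float * dst = nullptr;
    char ** table = nullptr;
    cudaMalloc((void **) &src, n * sizeof(float));
    cudaMalloc((void **) &dst, n * sizeof(float));
    cudaMalloc((void **) &table, sizeof(char *));
    char * dst_entry = (char *) dst;   // table entry 0 points at dst
    cudaMemcpy(table, &dst_entry, sizeof(char *), cudaMemcpyHostToDevice);
    copy_f32_indirect<<<(n + 255) / 256, 256>>>(src, dst, table, 0, n);
    cudaDeviceSynchronize();
    printf("copy status: %s\n", cudaGetErrorString(cudaGetLastError()));
    return 0;
}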
View file
@ -7,3 +7,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
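The newly declared ggml_cuda_cpy_dest_ptrs_copy is the host-side hook that pushes the collected copy destinations to the device before the graph is launched (it is called from check_node_graph_compatibility_and_refresh_copy_ops below). A stand-in showing roughly what such an upload helper has to do; the buffer management details here are assumptions, not the commit's implementation:

#include <cuda_runtime.h>

// Hypothetical stand-in: keep a device-side table of destination pointers and
// refresh it from the host-side list on every call.
struct cpy_dest_table {
    char ** dev = nullptr;   // device buffer holding the pointers
    int     cap = 0;         // current capacity in entries
};

static void upload_dest_ptrs(cpy_dest_table & t, char ** host_dest_ptrs,
                             int count, cudaStream_t stream) {
    if (count > t.cap) {                        // grow the device table if needed
        if (t.dev != nullptr) {
            cudaFree(t.dev);
        }
        cudaMalloc((void **) &t.dev, count * sizeof(char *));
        t.cap = count;
    }
    cudaMemcpyAsync(t.dev, host_dest_ptrs, count * sizeof(char *),
                    cudaMemcpyHostToDevice, stream);
}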
View file
@ -2446,10 +2446,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph // Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->updated_kernel_arg.clear(); cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i]; ggml_tensor * node = cgraph->nodes[i];
@ -2481,8 +2482,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
} }
if (node->op == GGML_OP_CPY) { if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); // Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later // store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) { if (!ptr) {
@ -2490,10 +2494,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#ifndef NDEBUG #ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif #endif
} else {
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
}
} }
} }
@ -2502,6 +2502,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
} }
} }
if (use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
}
return use_cuda_graph; return use_cuda_graph;
} }
@ -2556,51 +2562,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return true; return true;
} }
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
if (cuda_graph_update_required) {
// Extract nodes from graph
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
// Loop over nodes, and extract kernel parameters from each node
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
cudaGraphNodeType node_type;
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
if (node_type == cudaGraphNodeTypeKernel) {
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
if (stat == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
(void)cudaGetLastError();
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
}
}
} else {
// One of the arguments to the copy kernel is updated for each token, hence we need to
// replace that argument with the updated value in the CUDA graph
// on update steps, the live parameters will already be captured
int k = 0;
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
*(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
}
}
}
}
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool cuda_graph_update_required = false; bool cuda_graph_update_required = false;
@ -2660,8 +2621,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#endif #endif
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
[[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
bool & cuda_graph_update_required) {
while (!graph_evaluated_or_captured) { while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@ -2711,13 +2671,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
} }
if (cuda_graph_update_required) { // Update graph executable
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
// Update graph executable
update_cuda_graph_executable(cuda_ctx); update_cuda_graph_executable(cuda_ctx);
}
// Launch graph // Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
#else #else
@ -2731,10 +2687,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
ggml_cuda_set_device(cuda_ctx->device); ggml_cuda_set_device(cuda_ctx->device);
// vector of pointers to CUDA cpy kernels, which are required to identify
// kernel parameters which need updated in the graph for each token
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
@ -2768,8 +2720,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) { if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) { if (use_cuda_graph && cuda_graph_update_required) {
@ -2790,6 +2741,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
} }
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
}
#else #else
bool use_cuda_graph = false; bool use_cuda_graph = false;
bool cuda_graph_update_required = false; bool cuda_graph_update_required = false;
@ -2797,7 +2752,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
bool graph_evaluated_or_captured = false; bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
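With the per-node parameter patching removed, the backend's graph path boils down to: refresh the destination table, re-capture the stream into a graph when an update is required, keep the executable in sync, and launch it; steady-state tokens skip the capture entirely because the copy kernels read their destinations through the indirection table. A compressed sketch of that control flow using hypothetical helper names (error handling omitted):

#include <cuda_runtime.h>

// record_all_ops() stands in for evaluating the ggml graph on the stream; the
// destination-pointer table is assumed to have been refreshed beforehand.
template <typename F>
static void run_with_cuda_graph(cudaStream_t stream, bool update_required,
                                cudaGraph_t & graph, cudaGraphExec_t & instance,
                                F && record_all_ops) {
    if (update_required || instance == nullptr) {
        if (graph != nullptr) {
            cudaGraphDestroy(graph);           // drop the stale capture
            graph = nullptr;
        }
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        record_all_ops();                      // capture every kernel launch
        cudaStreamEndCapture(stream, &graph);
        if (instance != nullptr) {
            cudaGraphExecDestroy(instance);    // simplest possible "update": rebuild
            instance = nullptr;
        }
        cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    }
    cudaGraphLaunch(instance, stream);         // replay the captured work
}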
View file
@ -4,13 +4,14 @@ template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int nc, const int ncs, const int nr, const int n_t, const int n_s) { const int64_t n_t) {
GGML_UNUSED(src0_nb0);
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int bidx = blockIdx.x; const int bidx = blockIdx.x;
const int bidy = blockIdx.y; const int bidy = blockIdx.y;
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
const int stride_x = src0_nb1 / sizeof(float); const int stride_x = src0_nb1 / sizeof(float);
@ -21,15 +22,15 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
float w[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f };
#pragma unroll #pragma unroll
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j]; w[j] = w_block[tid * stride_w + j];
} }
for (int i = 0; i < n_t; i++) { for (int64_t i = 0; i < n_t; i++) {
float sumf = 0.0f; float sumf = 0.0f;
if (i == 0) { if (i == 0) {
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j]; x[j] = x_block[tid * stride_x + j];
} }
} else { } else {
@ -37,27 +38,26 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
} }
#pragma unroll #pragma unroll
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j]; sumf += x[(i + j) % d_conv] * w[j];
} }
y_block[i * stride_y + tid] = sumf; y_block[i * stride_y + tid] = sumf;
} }
} }
template <size_t split_d_inner, size_t d_conv, size_t split_n_t> template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int nc, const int ncs, const int dst_nb1, const int dst_nb2, const int64_t n_t) {
const int nr, const int n_t, const int n_s) {
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int bidx = blockIdx.x; const int bidx = blockIdx.x;
const int bidy = blockIdx.y; const int bidy = blockIdx.y;
const int bidz = blockIdx.z; const int bidz = blockIdx.z;
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0); bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1); const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0); (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
@ -69,17 +69,17 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
float w[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f };
#pragma unroll #pragma unroll
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j]; w[j] = w_block[tid * stride_w + j];
} }
#pragma unroll #pragma unroll
for (int i = 0; i < split_n_t; i++) { for (int64_t i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) { if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f; float sumf = 0.0f;
if (i == 0) { if (i == 0) {
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j]; x[j] = x_block[tid * stride_x + j];
} }
} else { } else {
@ -87,7 +87,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
} }
#pragma unroll #pragma unroll
for (int j = 0; j < d_conv; j++) { for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j]; sumf += x[(i + j) % d_conv] * w[j];
} }
y_block[i * stride_y + tid] = sumf; y_block[i * stride_y + tid] = sumf;
@ -97,8 +97,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
const int n_s, cudaStream_t stream) { const int64_t n_s, cudaStream_t stream) {
const int threads = 128; const int threads = 128;
GGML_ASSERT(nr % threads == 0); GGML_ASSERT(nr % threads == 0);
@ -106,18 +106,16 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) { if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
n_s);
} else { } else {
GGML_ABORT("Only support kernel size = 4 now."); GGML_ABORT("Only support kernel size = 4 now.");
} }
} else { } else {
if (nc == 4) { if (nc == 4) {
const int split_n_t = 32; const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t> ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
} else { } else {
GGML_ABORT("Only support kernel size = 4 right now."); GGML_ABORT("Only support kernel size = 4 right now.");
} }
@ -128,11 +126,10 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const int nc = src1->ne[0]; // d_conv const int64_t nc = src1->ne[0]; // d_conv
const int ncs = src0->ne[0]; // d_conv - 1 + n_t const int64_t nr = src0->ne[1]; // d_inner
const int nr = src0->ne[1]; // d_inner const int64_t n_t = dst->ne[1]; // tokens per sequence
const int n_t = dst->ne[1]; // tokens per sequence const int64_t n_s = dst->ne[2]; // number of sequences in the batch
const int n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT(dst->ne[0] == nr); GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
@ -147,5 +144,5 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, ncs, nr, n_t, n_s, stream); dst->nb[2], nc, nr, n_t, n_s, stream);
} }
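The kernel keeps a d_conv-wide window of x in registers and rotates through it with x[(i + j) % d_conv], so each new token only loads one extra element. For reference, the same depthwise causal convolution written as plain CPU code (d_conv = 4, one channel), handy for spot-checking the CUDA output; names are illustrative:

#include <cstdio>
#include <vector>

// y[t] = sum_{j<4} x[t + j] * w[j], where x carries d_conv - 1 = 3 leading
// context values in front of the n_t new tokens (as in ggml's conv_x layout).
static void ssm_conv_ref(const std::vector<float> & x,  // size n_t + 3
                         const float w[4],
                         std::vector<float> & y) {      // size n_t
    const int n_t = (int) y.size();
    for (int t = 0; t < n_t; ++t) {
        float sum = 0.0f;
        for (int j = 0; j < 4; ++j) {
            sum += x[t + j] * w[j];
        }
        y[t] = sum;
    }
}

int main() {
    const float w[4] = { 0.1f, 0.2f, 0.3f, 0.4f };
    std::vector<float> x = { 0, 0, 0, 1, 2, 3, 4 };     // 3 context values + 4 tokens
    std::vector<float> y(4);
    ssm_conv_ref(x, w, y);
    for (float v : y) printf("%g ", v);
    printf("\n");
    return 0;
}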
View file
@ -1,10 +1,5 @@
#include "ssm-scan.cuh" #include "ssm-scan.cuh"
// #include <cuda_runtime.h>
// static __device__ void global_to_shared(const float *src, float *dst) {
// asm volatile("cp.async.");
// }
template <size_t splitD, size_t N> template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2) __global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@ -12,7 +7,9 @@ __global__ void __launch_bounds__(splitD, 2)
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int D, const int L, const int B) { float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int bidx = blockIdx.x; // split along B const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x; const int tid = threadIdx.x;
@ -25,12 +22,12 @@ __global__ void __launch_bounds__(splitD, 2)
float * smem_A = smem; float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA; float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1); const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2)); const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2)); const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
@ -46,7 +43,7 @@ __global__ void __launch_bounds__(splitD, 2)
// can N not be 16? for example 32? // can N not be 16? for example 32?
if (N == 16) { if (N == 16) {
#pragma unroll #pragma unroll
for (int i = 0; i < splitD / 4; i += 2) { for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid]; float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict // todo: bank conflict
// I am always confused with how to use the swizzling method to solve // I am always confused with how to use the swizzling method to solve
@ -54,7 +51,7 @@ __global__ void __launch_bounds__(splitD, 2)
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
} }
#pragma unroll #pragma unroll
for (int i = 0; i < splitD / 4; i += 2) { for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
} }
@ -62,7 +59,7 @@ __global__ void __launch_bounds__(splitD, 2)
__syncthreads(); __syncthreads();
for (int i = 0; i < L; i++) { for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid]; float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) { if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus)); dt_soft_plus = log1pf(exp(dt_soft_plus));
@ -70,7 +67,7 @@ __global__ void __launch_bounds__(splitD, 2)
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus; float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
#pragma unroll #pragma unroll
for (int j = 0; j < N; j++) { for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) + float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt); (B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j]; sumf += state * C_block[i * stride_C + j];
@ -90,7 +87,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) { float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream) {
const int threads = 128; const int threads = 128;
// todo: consider D cannot be divided,does this situation exist? // todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0); GGML_ASSERT(D % threads == 0);
@ -99,7 +97,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
if (N == 16) { if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B); src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
} else { } else {
GGML_ABORT("doesn't support N!=16."); GGML_ABORT("doesn't support N!=16.");
} }
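The scan kernel applies the selective-state recurrence per token: dt goes through a guarded softplus, then each of the N = 16 state dimensions is updated as state = state * exp(dt * A) + B * (x * dt) and the output is the dot product with C. A scalar reference of one channel over L tokens (assumed layouts, no shared-memory tiling):

#include <cmath>
#include <cstdio>

// One channel of the selective scan over L tokens with N = 16 state dims.
// state and A are per-channel; x and dt are per token; B and C are [L][16].
static void ssm_scan_ref(float state[16], const float A[16],
                         const float * x, const float * dt,
                         const float * B, const float * C,
                         float * y, int L) {
    for (int t = 0; t < L; ++t) {
        const float dts  = dt[t] <= 20.0f ? log1pf(expf(dt[t])) : dt[t]; // softplus with overflow guard
        const float x_dt = x[t] * dts;
        float sum = 0.0f;
        for (int j = 0; j < 16; ++j) {
            state[j] = state[j] * expf(dts * A[j]) + B[t * 16 + j] * x_dt;
            sum     += state[j] * C[t * 16 + j];
        }
        y[t] = sum;
    }
}

int main() {
    float state[16] = { 0 }, A[16], B[32], C[32];
    for (int j = 0; j < 16; ++j) A[j] = -1.0f;
    for (int k = 0; k < 32; ++k) { B[k] = 0.01f; C[k] = 1.0f; }
    const float x[2] = { 1.0f, 0.5f }, dt[2] = { 0.1f, 0.2f };
    float y[2];
    ssm_scan_ref(state, A, x, dt, B, C, y, 2);
    printf("y = %g %g\n", y[0], y[1]);
    return 0;
}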
View file
@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
// ne00*(nsg) // ne00*(nsg)
// each simdgroup has a full f16 head vector in shared mem to accumulate results // each simdgroup has a full f16 head vector in shared mem to accumulate results
// //
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2; int64_t nsgmax = 2;
while (true) { while (true) {
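The jump from 2*ncpsg to 4*ncpsg in FATTN_SMEM pairs with the kernel change further down: the softmax scratch is now kept in float, and since the threadgroup buffer is still sized in half units each float row costs two slots. A quick arithmetic check of the macro under illustrative values of nqptg, ne00, ncpsg and ne20:

#include <cstdio>

// GGML_PAD rounds x up to a multiple of n (same definition as in ggml.h).
static long pad(long x, long n) { return ((x + n - 1) / n) * n; }

// FATTN_SMEM with the per-simdgroup scratch factor made explicit:
// 2 for half scratch (old), 4 for float scratch (new); result is bytes.
static long fattn_smem(long nsg, long factor) {
    const long nqptg = 8, ne00 = 128, ncpsg = 32, ne20 = 128;   // illustrative sizes
    return pad((nqptg * (pad(ne00, 128) + factor * ncpsg * nsg) + ne20 * nsg) * 2, 16);
}

int main() {
    for (long nsg = 1; nsg <= 4; nsg *= 2) {
        printf("nsg=%ld  half scratch=%ld B  float scratch=%ld B\n",
               nsg, fattn_smem(nsg, 2), fattn_smem(nsg, 4));
    }
    return 0;
}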
View file
@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
{ {
half S[Q] = { [0 ... Q-1] = 0.0f }; float S[Q] = { [0 ... Q-1] = 0.0f };
half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
// thread indices inside the simdgroup // thread indices inside the simdgroup
// TODO: see if we can utilize quad-group functions for better performance // TODO: see if we can utilize quad-group functions for better performance
@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
const bool has_mask = mask != q; const bool has_mask = mask != q;
half slope = 1.0f; float slope = 1.0f;
// ALiBi // ALiBi
if (args.max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const short h = iq2; const short h = iq2;
const half base = h < args.n_head_log2 ? args.m0 : args.m1; const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph); slope = pow(base, exph);
@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
if (has_mask) { if (has_mask) {
// used to detect blocks full of -INF // used to detect blocks full of -INF
half smax = -INFINITY; float smax = -INFINITY;
// load the mask in shared memory // load the mask in shared memory
#pragma unroll(Q) #pragma unroll(Q)
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
const half m = pm[ic + tiisg]; const float m = pm[ic + tiisg];
ss[j*TS + C + tiisg] = m; ss[j*TS + C + tiisg] = m;
smax = max(smax, m); smax = max(smax, m);
@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
// online softmax // online softmax
{ {
for (ushort j = 0; j < Q; ++j) { for (ushort j = 0; j < Q; ++j) {
const half m = M[j]; const float m = M[j];
// scale and apply the logitcap / mask // scale and apply the logitcap / mask
half s = ss[j*TS + tiisg]*args.scale; float s = ss[j*TS + tiisg]*args.scale;
if (args.logit_softcap != 0.0f) { if (args.logit_softcap != 0.0f) {
s = args.logit_softcap*precise::tanh(s); s = args.logit_softcap*precise::tanh(s);
@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
M[j] = simd_max(max(M[j], s)); M[j] = simd_max(max(M[j], s));
const half ms = exp(m - M[j]); const float ms = exp(m - M[j]);
const half vs = exp(s - M[j]); const float vs = exp(s - M[j]);
S[j] = S[j]*ms + simd_sum(vs); S[j] = S[j]*ms + simd_sum(vs);
@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
// reduce the warps sequentially // reduce the warps sequentially
for (ushort sg = 1; sg < nsg; ++sg) { for (ushort sg = 1; sg < nsg; ++sg) {
half S = { 0.0f }; float S = { 0.0f };
half M = { -__FLT16_MAX__/2 }; float M = { -__FLT16_MAX__/2 };
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
// the first simdgroup accumulates the results from the other simdgroups // the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) { if (sgitg == 0) {
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
const half S0 = ss[j*TS + 0]; const float S0 = ss[j*TS + 0];
const half S1 = ss[j*TS + sg*SH + 0]; const float S1 = ss[j*TS + sg*SH + 0];
const half M0 = ss[j*TS + 1]; const float M0 = ss[j*TS + 1];
const half M1 = ss[j*TS + sg*SH + 1]; const float M1 = ss[j*TS + sg*SH + 1];
M = max(M0, M1); M = max(M0, M1);
const half ms0 = exp(M0 - M); const float ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M); const float ms1 = exp(M1 - M);
S = S0*ms0 + S1*ms1; S = S0*ms0 + S1*ms1;
@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
constexpr short DV4 = DV/4; constexpr short DV4 = DV/4;
constexpr short NW = N_SIMDWIDTH; constexpr short NW = N_SIMDWIDTH;
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
constexpr short SH = 2*C; // shared memory per simdgroup constexpr short SH = 4*C; // shared memory per simdgroup
const short T = DK + nsg*SH; // shared memory size per query in (half) const short T = DK + nsg*SH; // shared memory size per query in (half)
@ -3654,7 +3654,7 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
// store the result for all queries in local memory (the O matrix from the paper) // store the result for all queries in local memory (the O matrix from the paper)
@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
{ {
half S = 0.0f; float S = 0.0f;
half M = -__FLT16_MAX__/2; float M = -__FLT16_MAX__/2;
// thread indices inside the simdgroup // thread indices inside the simdgroup
const short tx = tiisg%NL; const short tx = tiisg%NL;
@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
// pointer to the mask // pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31); device const half * pm = (device const half *) (mask + iq1*args.nb31);
half slope = 1.0f; float slope = 1.0f;
// ALiBi // ALiBi
if (args.max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const short h = iq2; const short h = iq2;
const half base = h < args.n_head_log2 ? args.m0 : args.m1; const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph); slope = pow(base, exph);
@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
// online softmax // online softmax
{ {
const half m = M; const float m = M;
const half s = ss[tiisg]; const float s = ss[tiisg];
M = simd_max(max(M, s)); M = simd_max(max(M, s));
const half ms = exp(m - M); const float ms = exp(m - M);
const half vs = exp(s - M); const float vs = exp(s - M);
S = S*ms + simd_sum(vs); S = S*ms + simd_sum(vs);
@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
v4_t mv; v4_t mv;
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
lo[ii/NL] += mv*ms; lo[ii/NL] += o4_t(float4(mv)*float4(ms));
} }
} }
} }
@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
// parallel reduce // parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) { for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) { if (sgitg < r) {
const half S0 = ss[ 0]; const float S0 = ss[ 0];
const half S1 = ss[r*SH + 0]; const float S1 = ss[r*(SH/2) + 0];
const half M0 = ss[ 1]; const float M0 = ss[ 1];
const half M1 = ss[r*SH + 1]; const float M1 = ss[r*(SH/2) + 1];
const half M = max(M0, M1); const float M = max(M0, M1);
const half ms0 = exp(M0 - M); const float ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M); const float ms1 = exp(M1 - M);
const half S = S0*ms0 + S1*ms1; const float S = S0*ms0 + S1*ms1;
if (tiisg == 0) { if (tiisg == 0) {
ss[0] = S; ss[0] = S;
@ -3954,7 +3954,7 @@ kernel void kernel_flash_attn_ext_vec(
half4, \ half4, \
half4, \ half4, \
float, \ float, \
half, half4, \ float, float4, \
half4 half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
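The half-to-float switch for S, M, slope and the mask scratch is about range as much as precision: the online-softmax running sum grows with the number of KV positions and can blow straight past half's largest finite value (about 65504) on long contexts, while the per-step rescale factors stay well behaved in float. A standalone illustration of the update rule and why the accumulator needs float (plain C++, not the Metal kernel):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Online softmax running state: M is the running max, S the running sum of
// exp(s - M). With equal scores S ends up equal to the number of positions,
// which alone exceeds the half range once the context passes ~65k tokens.
struct online_softmax {
    float M = -INFINITY;
    float S = 0.0f;
    void add(float s) {
        const float m_old = M;
        M = std::max(M, s);
        const float ms = std::exp(m_old - M);  // rescale what was accumulated so far
        const float vs = std::exp(s - M);      // contribution of the new score
        S = S * ms + vs;
    }
};

int main() {
    online_softmax os;
    for (int i = 0; i < 100000; ++i) {
        os.add(0.0f);
    }
    printf("S = %.1f\n", os.S);                // 100000.0, far beyond half's 65504
    return 0;
}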
View file
@ -921,10 +921,30 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts);
CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
// TODO: fixme: these sizes are hardcoded for now.
// they should be allocated based on the model's size
// and the device's max alloc size
// Allocate intermediate buffers and images // Allocate intermediate buffers and images
size_t max_A_q_d_bytes = 311164928; size_t required_A_q_d_bytes = 311164928;
size_t max_A_s_d_bytes = 38895616; size_t required_A_s_d_bytes = 38895616;
size_t max_B_d_bytes = 45088768; size_t required_B_d_bytes = 45088768;
// Ensure buffer sizes do not exceed the maximum allocation size
size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size);
size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size);
size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size);
if (required_A_q_d_bytes > backend_ctx->max_alloc_size) {
GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n",
required_A_q_d_bytes, max_A_q_d_bytes);
}
if (required_A_s_d_bytes > backend_ctx->max_alloc_size) {
GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n",
required_A_s_d_bytes, max_A_s_d_bytes);
}
if (required_B_d_bytes > backend_ctx->max_alloc_size) {
GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n",
required_B_d_bytes, max_B_d_bytes);
}
CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err)); CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err)); CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
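Instead of blindly allocating the hard-coded scratch sizes, the requested values are now clamped to the device's maximum single allocation and a warning is printed whenever the clamp actually triggers. The same pattern in isolation (hypothetical names, plain C++ rather than the OpenCL context code):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Clamp a requested scratch-buffer size to the device limit, warning when the
// request had to be reduced (mirrors the MIN + GGML_LOG_WARN pattern above).
static size_t clamp_to_max_alloc(const char * name, size_t requested, size_t max_alloc) {
    const size_t actual = std::min(requested, max_alloc);
    if (requested > max_alloc) {
        fprintf(stderr, "%s buffer reduced from %zu to %zu bytes due to device limits\n",
                name, requested, actual);
    }
    return actual;
}

int main() {
    const size_t max_alloc = 128u * 1024 * 1024;   // pretend the device caps at 128 MiB
    const size_t a_q = clamp_to_max_alloc("A_q_d", 311164928u, max_alloc);
    const size_t b_d = clamp_to_max_alloc("B_d",    45088768u, max_alloc);
    printf("A_q_d=%zu B_d=%zu\n", a_q, b_d);
    return 0;
}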
View file
@ -38,6 +38,7 @@
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_AMD 0x1002
#define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_APPLE 0x106b
@ -359,6 +360,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_split_k_reduce;
std::unordered_map<std::string, vk_pipeline_ref> pipelines; std::unordered_map<std::string, vk_pipeline_ref> pipelines;
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements; std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@ -508,6 +510,10 @@ struct vk_flash_attn_push_constants {
uint32_t n_head_log2; uint32_t n_head_log2;
float m0; float m0;
float m1; float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
}; };
struct vk_op_push_constants { struct vk_op_push_constants {
@ -1480,7 +1486,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
// small rows, large cols // small rows, large cols
if (small_rows) { if (small_rows) {
return {flash_attention_num_small_rows, 128}; return {flash_attention_num_small_rows, 64};
} }
// small cols to reduce register count // small cols to reduce register count
if (ggml_is_quantized(type) || D == 256) { if (ggml_is_quantized(type) || D == 256) {
@ -2336,6 +2342,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@ -5417,7 +5424,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t nbm1 = mask ? mask->nb[1] : 0; const uint32_t nbm1 = mask ? mask->nb[1] : 0;
const uint32_t D = neq0; const uint32_t D = neq0;
const uint32_t N = neq1; uint32_t N = neq1;
const uint32_t KV = nek1; const uint32_t KV = nek1;
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == D);
@ -5475,9 +5482,54 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline pipeline = pipelines[aligned]; vk_pipeline pipeline = pipelines[aligned];
assert(pipeline); assert(pipeline);
uint32_t gqa_ratio = 1;
uint32_t qk_ratio = neq2 / nek2;
uint32_t workgroups_x = (uint32_t)neq1;
uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3;
if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= N;
}
uint32_t split_kv = KV;
uint32_t split_k = 1;
if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
GGML_ASSERT(workgroups_x == 1);
// Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
}
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows).
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large");
}
if (ctx->prealloc_size_split_k < split_k_size) {
ctx->prealloc_size_split_k = split_k_size;
}
if (dryrun) { if (dryrun) {
// Request descriptor sets // Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
}
return; return;
} }
@ -5498,8 +5550,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_sync_buffers(subctx);
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
@ -5564,7 +5614,35 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride, (uint32_t)nbv2, (uint32_t)nbv3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1, nbm1,
scale, max_bias, logit_softcap, scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1 }; mask != nullptr, n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
ggml_vk_sync_buffers(subctx);
if (split_k > 1) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ {
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@ -5573,7 +5651,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
}, },
sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
}
} }
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
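Split-k is only engaged on the grouped-query path (where workgroups_x is 1), and it aims for roughly two workgroups per shader core: split_k starts from shader_core_count * 2 / workgroups_y, KV is carved into align-rounded chunks, and split_k is recomputed from the chunk size; each split then needs D * ne1 floats for its partial O plus 2 * ne1 floats for the per-row L and M. The same sizing math on the host, with illustrative device and tensor constants:

#include <cstdint>
#include <cstdio>

static uint32_t roundup_pow2(uint32_t m, uint32_t n) { return (m + n - 1) & ~(n - 1); } // n is a power of two
static uint32_t ceil_div(uint32_t m, uint32_t n) { return (m + n - 1) / n; }

int main() {
    const uint32_t shader_core_count = 32;    // illustrative device value
    const uint32_t workgroups_y      = 8;     // heads after dividing by gqa_ratio
    const uint32_t KV                = 8192;  // KV cache length
    const uint32_t align             = 64;    // pipeline alignment
    const uint32_t D                 = 128;   // head size
    const uint32_t ne1               = 32;    // output rows

    uint32_t split_kv = KV;
    uint32_t split_k  = shader_core_count * 2 / workgroups_y;
    if (split_k > 1) {
        split_kv = roundup_pow2(KV / split_k, align);   // chunk size, multiple of align
        split_k  = ceil_div(KV, split_kv);              // actual number of chunks
    }
    // per-split temporaries: partial O (D x ne1 floats) plus per-row L and M
    const uint64_t split_k_size = split_k > 1
        ? (uint64_t) (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k
        : 0;
    printf("split_k=%u split_kv=%u scratch=%llu bytes\n",
           split_k, split_kv, (unsigned long long) split_k_size);
    return 0;
}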
View file
@ -61,6 +61,10 @@ layout (push_constant) uniform parameter {
uint32_t n_head_log2; uint32_t n_head_log2;
float m0; float m0;
float m1; float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p; } p;
layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
#define DECODEFUNC #define DECODEFUNC
#endif #endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < D) {
uint32_t offset = (iq2 + r) * D + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
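perElemOpComputeSlope exists because a workgroup now covers gqa_ratio different heads, so each packed row needs its own ALiBi slope (h = iq2 + (r & (gqa_ratio - 1))) rather than the single scalar slope the shader used before. The same head-to-slope mapping in scalar form, with illustrative head counts and bias:

#include <cmath>
#include <cstdint>
#include <cstdio>

// ALiBi slope for head h, matching the shader: heads below n_head_log2 use base m0
// with exponent h + 1, the rest use base m1 with odd exponents.
static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? (int) h + 1 : 2 * (int) (h - n_head_log2) + 1;
    return powf(base, (float) exph);
}

int main() {
    const uint32_t n_head_log2 = 8;                  // illustrative
    const float    max_bias    = 8.0f;
    const float    m0 = powf(2.0f, -max_bias / n_head_log2);
    const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    const uint32_t iq2 = 6, gqa_ratio = 4;           // rows cover heads 6..9
    for (uint32_t r = 0; r < gqa_ratio; ++r) {
        const uint32_t h = iq2 + (r & (gqa_ratio - 1));  // row -> head, as in perElemOpComputeSlope
        printf("row %u -> head %u slope %g\n", r, h, alibi_slope(h, n_head_log2, m0, m1));
    }
    return 0;
}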
void main() { void main() {
#ifdef NEEDS_INIT_IQ_SHMEM #ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize); init_iq_shmem(gl_WorkGroupSize);
@ -111,12 +147,22 @@ void main() {
const uint32_t N = p.N; const uint32_t N = p.N;
const uint32_t KV = p.KV; const uint32_t KV = p.KV;
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br); const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t Tc = CEIL_DIV(KV, Bc);
const uint32_t i = gl_WorkGroupID.x; const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
const uint32_t iq2 = gl_WorkGroupID.y; // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z; const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors // broadcast factors
@ -149,8 +195,10 @@ void main() {
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
// nb?1 are already divided by the type size and are in units of elements // nb?1 are already divided by the type size and are in units of elements.
uint32_t q_stride = p.nb01; // When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11; uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21; uint32_t v_stride = p.nb21;
// hint to the compiler that strides are aligned for the aligned variant of the shader // hint to the compiler that strides are aligned for the aligned variant of the shader
@ -182,20 +230,15 @@ void main() {
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0); M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
ACC_TYPE slope = ACC_TYPE(1.0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi // ALiBi
if (p.max_bias > 0.0f) { if (p.max_bias > 0.0f) {
const uint32_t h = iq2; coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
slope = pow(base, ACC_TYPE(exph));
} }
[[dont_unroll]] [[dont_unroll]]
for (uint32_t j = 0; j < Tc; ++j) { for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
@ -215,12 +258,16 @@ void main() {
if (p.mask != 0) { if (p.mask != 0) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
// When using grouped query attention, all rows use the same mask.
if (p.gqa_ratio > 1) {
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
}
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} }
// Clear padding elements to -inf, so they don't contribute to rowmax // Clear padding elements to -inf, so they don't contribute to rowmax
@ -285,6 +332,20 @@ void main() {
O = coopMatMulAdd(P_A, V, O); O = coopMatMulAdd(P_A, V, O);
} }
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * split_k_index;
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce // resize L by using smear/reduce
@ -297,13 +358,18 @@ void main() {
O = Ldiag*O; O = Ldiag*O;
uint32_t o_offset = iq3*p.ne2*p.ne1;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
// permute dimensions // permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
uint32_t o_offset = iq3*p.ne2*p.ne1;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); }
} }
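When k_num > 1, each split writes its partial O tile plus one (L, M) pair per output row into a single scratch buffer, and the split_k reduce shader below consumes that layout. A minimal C++ sketch of the same offset arithmetic, assuming the layout implied by the stores above (the struct and its names are made up for illustration):

#include <cstdint>

// Assumed scratch layout for flash-attention split_k:
//   k_num tiles of D*ne1 partial O values, followed by per-split (L, M) pairs per row.
struct fa_split_k_layout {
    uint32_t D;      // head dimension
    uint32_t ne1;    // number of output rows
    uint32_t k_num;  // number of KV splits

    uint32_t o_index(uint32_t split, uint32_t row, uint32_t col) const {
        return D * ne1 * split + row * D + col;
    }
    uint32_t l_index(uint32_t split, uint32_t row) const {
        return D * ne1 * k_num + ne1 * split * 2 + row;
    }
    uint32_t m_index(uint32_t split, uint32_t row) const {
        return l_index(split, row) + ne1;    // M is stored right after L for each split
    }
};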

@ -0,0 +1,59 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 32
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint k_num;
} p;
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * k_num + n;
uint m_offset = D * N * k_num + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float m = data_a[m_offset + k * lm_stride];
m_max = max(m_max, m);
}
// Compute L based on m_max
float L = 0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float l = data_a[l_offset + k * lm_stride];
float m = data_a[m_offset + k * lm_stride];
L += exp(m - m_max) * l;
}
L = 1.0 / L;
// Scale and sum the O contributions based on m_max and store the result to memory
for (uint d = tid; d < D; d += BLOCK_SIZE) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
O *= L;
data_d[D * n + d] = O;
}
}
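The reduce pass combines the per-split partials with the usual numerically stable log-sum-exp rescaling: take the maximum of the per-split row maxima, rescale each split's L and O by exp(m_k - m_max), then divide by the combined L. A scalar C++ reference of the same math for a single row, assuming the scratch layout sketched earlier (a readability aid, not the shader itself):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// o_parts[k] holds the D partial outputs of split k; l[k]/m[k] are its row sum and row max.
static std::vector<float> fa_reduce_row(const std::vector<std::vector<float>> & o_parts,
                                        const std::vector<float> & l,
                                        const std::vector<float> & m) {
    const size_t k_num = o_parts.size();
    const size_t D     = o_parts[0].size();

    float m_max = -INFINITY;
    for (size_t k = 0; k < k_num; ++k) {
        m_max = std::max(m_max, m[k]);
    }

    float L = 0.0f;
    for (size_t k = 0; k < k_num; ++k) {
        L += std::exp(m[k] - m_max) * l[k];
    }

    std::vector<float> o(D, 0.0f);
    for (size_t k = 0; k < k_num; ++k) {
        const float scale = std::exp(m[k] - m_max);
        for (size_t d = 0; d < D; ++d) {
            o[d] += scale * o_parts[k][d];
        }
    }
    for (size_t d = 0; d < D; ++d) {
        o[d] /= L;    // final softmax normalization
    }
    return o;
}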

@ -234,9 +234,9 @@ void main() {
#endif #endif
#if QUANT_AUXF == 1 #if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[TM]; FLOAT_TYPE cache_a_dm[WMITER * TM];
#else #else
FLOAT_TYPE_VEC2 cache_a_dm[TM]; FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif #endif
FLOAT_TYPE_VEC2 cache_b_ds[TN]; FLOAT_TYPE_VEC2 cache_b_ds[TN];
@ -247,7 +247,6 @@ void main() {
const uint iqs = loadr_a; const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l; const uint buf_ib = loadc_a + l;
// Should ds be gated to a single thread?
if (iqs == 0) { if (iqs == 0) {
#if QUANT_AUXF == 1 #if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib); buf_a_dm[buf_ib] = get_d(ib);
@ -276,7 +275,6 @@ void main() {
const uint buf_ib = loadc_b + l; const uint buf_ib = loadc_b + l;
// Should ds be gated to a single thread?
if (iqs == 0) { if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
} }

@ -17,7 +17,7 @@ i32vec2 repack(uint ib, uint iqs) {
} }
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y)); return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
} }
#endif #endif
@ -51,7 +51,7 @@ i32vec2 repack(uint ib, uint iqs) {
} }
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y)); return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
} }
#endif #endif
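The -8.0f / -16.0f terms come from the zero-point of the 4-bit and 5-bit block formats: their quants are stored unsigned with an implicit offset (8 for q4_0, 16 for q5_0), while q8_1 carries its block sum pre-scaled in dsb.y, so the offset can be folded in after the integer dot product. A scalar C++ sketch of the q4_0 case, with hypothetical helper names, to make the algebra explicit:

#include <cstdint>

// One q4_0 block (scale d4, unsigned nibbles q4 with an implicit -8 offset)
// dotted against one q8_1 block (scale d8, signed quants q8, s8 = d8 * sum(q8)).
static float dot_q4_0_q8_1_block(float d4, const uint8_t * q4,
                                 float d8, float s8, const int8_t * q8, int n) {
    int32_t q_sum = 0;
    for (int i = 0; i < n; ++i) {
        q_sum += int32_t(q4[i]) * int32_t(q8[i]);   // integer dot, offset not yet applied
    }
    // sum_i (q4[i] - 8) * d4 * q8[i] * d8
    //   = d4 * (q_sum * d8 - 8 * d8 * sum(q8))
    //   = d4 * (q_sum * d8 - 8 * s8)               // same shape as mul_q8_1 above
    return d4 * (float(q_sum) * d8 - 8.0f * s8);
}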

@ -474,6 +474,7 @@ void process_shaders() {
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

@ -1172,6 +1172,12 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
} }
size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
if (tensor->ne[i] <= 0) {
return 0;
}
}
size_t nbytes; size_t nbytes;
const size_t blck_size = ggml_blck_size(tensor->type); const size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) { if (blck_size == 1) {

@ -282,10 +282,18 @@ extern "C" {
}; };
}; };
struct llama_model_tensor_buft_override {
const char * pattern;
ggml_backend_buffer_type_t buft;
};
struct llama_model_params { struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices; ggml_backend_dev_t * devices;
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
int32_t n_gpu_layers; // number of layers to store in VRAM int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs enum llama_split_mode split_mode; // how to split the model across multiple GPUs
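tensor_buft_overrides is a NULL-terminated list of (regex pattern, buffer type) pairs matched against tensor names at load time (see the loader change further down). A hedged sketch of how a caller might fill it in, for example to keep MoE expert weights in host memory; the pattern string is illustrative:

#include "ggml-backend.h"
#include "llama.h"

// Pin tensors whose names contain "ffn_..._exps" to the CPU buffer type.
// The array must end with an entry whose pattern is nullptr.
static llama_model_params params_with_tensor_overrides() {
    static const llama_model_tensor_buft_override overrides[] = {
        { "ffn_.*_exps", ggml_backend_cpu_buffer_type() },   // illustrative pattern
        { nullptr,       nullptr                         },
    };

    llama_model_params params = llama_model_default_params();
    params.tensor_buft_overrides = overrides;
    return params;
}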

@ -3815,7 +3815,7 @@ Current version indicated by LITEVER below.
document.getElementById("lastreq1").innerHTML = document.getElementById("lastreq1").innerHTML =
document.getElementById("lastreq2").innerHTML = document.getElementById("lastreq2").innerHTML =
document.getElementById("lastreq3").innerHTML = document.getElementById("lastreq3").innerHTML =
`<a href="#" class="color_grayurl" onclick="msgbox('Source code is available at https://github.com/LostRuins/lite.koboldai.net \\nPlease report any bugs you find there.','Information')">KoboldAI Lite</a> v${LITEVER} Web - Frontend for <a href="#" class="color_grayurl" onclick="msgbox('KoboldAI Lite allows you to connect to various third-party AI services. We do not control or assume responsibility for the models or content generated by these services. The user is responsible for ensuring that their usage of this software is legal in their country, and complies with the terms of service of the service they are connected to. Use at your own discretion.','Disclaimer')">External API Services</a>`; `<a href="#" class="color_grayurl mainnav" title="KoboldAI Lite Information" onclick="msgbox('Source code is available at https://github.com/LostRuins/lite.koboldai.net \\nPlease report any bugs you find there.','Information')">KoboldAI Lite</a> v${LITEVER} Web - Frontend for <a href="#" class="color_grayurl mainnav" title="KoboldAI Lite Disclaimer" onclick="msgbox('KoboldAI Lite allows you to connect to various third-party AI services. We do not control or assume responsibility for the models or content generated by these services. The user is responsible for ensuring that their usage of this software is legal in their country, and complies with the terms of service of the service they are connected to. Use at your own discretion.','Disclaimer')">External API Services</a>`;
trigger_abort_controller(); //first trigger sets it up trigger_abort_controller(); //first trigger sets it up
@ -14987,14 +14987,26 @@ Current version indicated by LITEVER below.
} }
urlbase = default_gemini_base + mdlname + default_gemini_suffix; urlbase = default_gemini_base + mdlname + default_gemini_suffix;
let geminiparts = [];
if (insertAIVisionImages.length > 0) {
for (let i = 0; i < insertAIVisionImages.length; ++i) {
let oaiimg = {
"inline_data": {
"mime_type": "image/jpeg",
"data": insertAIVisionImages[i]
}
};
geminiparts.push(oaiimg);
}
}
geminiparts.push({"text": submit_payload.prompt});
let payload = { let payload = {
"contents": [ "contents": [
{ {
"parts": [ "parts": geminiparts
{
"text": submit_payload.prompt
}
]
} }
], ],
"safetySettings": [ "safetySettings": [
@ -15028,6 +15040,12 @@ Current version indicated by LITEVER below.
"stopSequences": [] "stopSequences": []
} }
}; };
if(document.getElementById("usegeminiweb").checked)
{
payload["tools"] = [{"google_search": {}}];
}
let sentrole = document.getElementById("geminiroledropdown").value; let sentrole = document.getElementById("geminiroledropdown").value;
if(sentrole!="") if(sentrole!="")
{ {
@ -15675,7 +15693,8 @@ Current version indicated by LITEVER below.
if(savedmeta.visionmode==4) if(savedmeta.visionmode==4)
{ {
let isoai = (custom_oai_key!="" && document.getElementById("useoaichatcompl").checked); let isoai = (custom_oai_key!="" && document.getElementById("useoaichatcompl").checked);
visionstatus = isoai?`<span class="color_green">OpenAI API (Conditional)</span>`:`<span class="color_yellow">Unsupported</span>`; let isgemini = (custom_gemini_key!="");
visionstatus = (isoai?`<span class="color_green">OpenAI API (Conditional)</span>`:(isgemini?`<span class="color_green">Gemini API (Conditional)</span>`:`<span class="color_yellow">Unsupported</span>`));
} }
else if(savedmeta.visionmode==3) else if(savedmeta.visionmode==3)
{ {
@ -22469,17 +22488,17 @@ Current version indicated by LITEVER below.
<span style="float:right;"> <span style="float:right;">
<a href="#" class="color_green" onclick="get_and_show_workers()">[See Current Volunteers] </a> <a href="#" class="color_green" onclick="get_and_show_workers()">[See Current Volunteers] </a>
</span> </span>
<select class="form-control" id="pickedmodel" size="7" multiple></select> <select title="AI Horde Target Selection" class="form-control" id="pickedmodel" size="7" multiple></select>
</div> </div>
<div class="menutext" style="text-align: left;"> <div class="menutext" style="text-align: left;">
Select By Worker <span class="helpicon">? Select By Worker <span class="helpicon">?
<span class="helptext">This option explicitly assigns worker IDs, fixed based on the current workers available at model selection time.</span> <span class="helptext">This option explicitly assigns worker IDs, fixed based on the current workers available at model selection time.</span>
</span> </span>
<input type="checkbox" id="manualworker" onclick="toggle_manual_horde_worker()"> <input title="Select by Worker" type="checkbox" id="manualworker" onclick="toggle_manual_horde_worker()">
<span style="float:right;"> <span style="float:right;">
<input class="settinglabel miniinput" style="margin: 3px; width: 90px;" type="text" placeholder="Quick Search" value="" id="modelquicksearch" oninput="model_quick_search()"> <input title="Quick Search" class="settinglabel miniinput" style="margin: 3px; width: 90px;" type="text" placeholder="Quick Search" value="" id="modelquicksearch" oninput="model_quick_search()">
</span> </span>
</div> </div>
@ -22489,8 +22508,8 @@ Current version indicated by LITEVER below.
You can use this to connect to a KoboldAI instance running via a remote tunnel such as <span class="color_orange" style="font-weight: bold;">trycloudflare, localtunnel, ngrok</span>.<br><br> You can use this to connect to a KoboldAI instance running via a remote tunnel such as <span class="color_orange" style="font-weight: bold;">trycloudflare, localtunnel, ngrok</span>.<br><br>
Localhost IPs require host mode enabled. You can use the remote address displayed in the <span class="color_orange" style="font-weight: bold;">terminal console</span> or <span class="color_orange" style="font-weight: bold;">colab window</span>, note that the model must be loaded first.<br><br> Localhost IPs require host mode enabled. You can use the remote address displayed in the <span class="color_orange" style="font-weight: bold;">terminal console</span> or <span class="color_orange" style="font-weight: bold;">colab window</span>, note that the model must be loaded first.<br><br>
<span class="color_green" style="font-weight: bold;">Please input URL of the KoboldAI instance.</span><br><br> <span class="color_green" style="font-weight: bold;">Please input URL of the KoboldAI instance.</span><br><br>
<input class="form-control" id="customkoboldendpoint" placeholder="https://sample-remote-address.trycloudflare.com" value=""> <input class="form-control" title="Enter KoboldCpp Custom Endpoint" id="customkoboldendpoint" placeholder="https://sample-remote-address.trycloudflare.com" value="">
<input class="form-control" type="password" id="customkoboldkey" placeholder="KoboldAI API Key (Optional)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br> <input class="form-control" title="Enter KoboldCpp API Key" type="password" id="customkoboldkey" placeholder="KoboldAI API Key (Optional)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
<div class="borderbox flex flex-push-right"> <div class="borderbox flex flex-push-right">
<input type="checkbox" id="remoteconsolelog"> <input type="checkbox" id="remoteconsolelog">
<div class="box-label" title="Will display outputs to the remote endpoint's console logs, useful for debugging.">Show Console Logging</div> <div class="box-label" title="Will display outputs to the remote endpoint's console logs, useful for debugging.">Show Console Logging</div>
@ -22526,7 +22545,7 @@ Current version indicated by LITEVER below.
<input class="form-control" type="text" id="custom_oai_endpoint" placeholder="OpenAI API URL" value="" onblur="try_fetch_oai_models_auto()"> <input class="form-control" type="text" id="custom_oai_endpoint" placeholder="OpenAI API URL" value="" onblur="try_fetch_oai_models_auto()">
<input class="form-control" type="password" id="custom_oai_key" placeholder="API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br> <input class="form-control" type="password" id="custom_oai_key" placeholder="API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
Model Choice:<br> Model Choice:<br>
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control" id="custom_oai_model" onchange="oai_model_change(true)"> <select title="OpenAI Model Selection" style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control" id="custom_oai_model" onchange="oai_model_change(true)">
<option value="gpt-3.5-turbo-instruct" selected="selected">gpt-3.5-turbo-instruct</option> <option value="gpt-3.5-turbo-instruct" selected="selected">gpt-3.5-turbo-instruct</option>
<option value="davinci-002">davinci-002</option> <option value="davinci-002">davinci-002</option>
<option value="gpt-3.5-turbo">gpt-3.5-turbo</option> <option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
@ -22539,7 +22558,7 @@ Current version indicated by LITEVER below.
<option value="o1-preview">o1-preview</option> <option value="o1-preview">o1-preview</option>
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option> <option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
</select> </select>
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_openrouter_model" onchange="oai_model_change(true)"> <select title="OpenRouter AI Model Selection" style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_openrouter_model" onchange="oai_model_change(true)">
<option value="openai/gpt-3.5-turbo">openai/gpt-3.5-turbo</option> <option value="openai/gpt-3.5-turbo">openai/gpt-3.5-turbo</option>
<option value="openai/gpt-4">openai/gpt-4</option> <option value="openai/gpt-4">openai/gpt-4</option>
<option value="openai/gpt-3.5-turbo-instruct">openai/gpt-3.5-turbo-instruct</option> <option value="openai/gpt-3.5-turbo-instruct">openai/gpt-3.5-turbo-instruct</option>
@ -22549,7 +22568,7 @@ Current version indicated by LITEVER below.
<option value="anthropic/claude-2.0">anthropic/claude-2.0</option> <option value="anthropic/claude-2.0">anthropic/claude-2.0</option>
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option> <option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
</select> </select>
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_mistralai_model" onchange="oai_model_change(true)"> <select title="Mistral AI Model Selection" style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_mistralai_model" onchange="oai_model_change(true)">
<option value="open-mistral-7b">open-mistral-7b</option> <option value="open-mistral-7b">open-mistral-7b</option>
<option value="open-mistral-nemo">open-mistral-nemo</option> <option value="open-mistral-nemo">open-mistral-nemo</option>
<option value="open-mixtral-8x22b">open-mixtral-8x22b</option> <option value="open-mixtral-8x22b">open-mixtral-8x22b</option>
@ -22562,7 +22581,7 @@ Current version indicated by LITEVER below.
<option value="codestral-latest">codestral-latest</option> <option value="codestral-latest">codestral-latest</option>
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option> <option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
</select> </select>
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_featherless_model" onchange="oai_model_change(true)"> <select title="Featherless AI Model Selection" style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_featherless_model" onchange="oai_model_change(true)">
<option value="Sao10K/L3-8B-Lunaris-v1">Sao10K/L3-8B-Lunaris-v1</option> <option value="Sao10K/L3-8B-Lunaris-v1">Sao10K/L3-8B-Lunaris-v1</option>
<option value="Sao10K/L3-8B-Stheno-v3.2">Sao10K/L3-8B-Stheno-v3.2</option> <option value="Sao10K/L3-8B-Stheno-v3.2">Sao10K/L3-8B-Stheno-v3.2</option>
<option value="unsloth/llama-3-8b-Instruct">unsloth/llama-3-8b-Instruct</option> <option value="unsloth/llama-3-8b-Instruct">unsloth/llama-3-8b-Instruct</option>
@ -22576,7 +22595,7 @@ Current version indicated by LITEVER below.
<option value="meta-llama/Meta-Llama-3.1-405B-Instruct">meta-llama/Meta-Llama-3.1-405B-Instruct</option> <option value="meta-llama/Meta-Llama-3.1-405B-Instruct">meta-llama/Meta-Llama-3.1-405B-Instruct</option>
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option> <option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
</select> </select>
<select style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_grok_model" onchange="oai_model_change(true)"> <select title="Grok AI Model Selection" style="padding:4px;display:inline;width:calc(100% - 220px)" class="form-control hidden" id="custom_grok_model" onchange="oai_model_change(true)">
<option value="grok-beta">grok-beta</option> <option value="grok-beta">grok-beta</option>
<option style="display:none;" class="custom_model_option" value="custom">[Custom]</option> <option style="display:none;" class="custom_model_option" value="custom">[Custom]</option>
</select> </select>
@ -22599,38 +22618,38 @@ Current version indicated by LITEVER below.
<span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();"> <span id="useoaichatcomplbox" class="hidden" onload="toggleoaichatcompl();">
<br> <br>
Main Message Role: Main Message Role:
<select class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="oairoledropdown"> <select title="Main Message Role" class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="oairoledropdown">
<option value="0" selected>User</option> <option value="0" selected>User</option>
<option value="1">Assistant</option> <option value="1">Assistant</option>
<option value="2">System</option> <option value="2">System</option>
<option value="3">AutoRole</option> <option value="3">AutoRole</option>
</select> </select>
<input type="checkbox" id="jailbreakprompt" onchange="togglejailbreak()"> <input type="checkbox" title="Add Prefix Prompt" id="jailbreakprompt" onchange="togglejailbreak()">
<div class="box-label" title="Adds extra text at the start to improve AI response">Add Prefix</div> <div class="box-label" title="Add Prefix. Forcefully inserts extra text at the start of the prompt to steer the AI.">Add Prefix</div>
<input type="checkbox" id="jailbreakprompt2" onchange="togglejailbreak2()"> <input type="checkbox" title="Add Postfix Prompt" id="jailbreakprompt2" onchange="togglejailbreak2()">
<div class="box-label" title="Adds extra text to the end to improve AI response">Add Postfix</div> <div class="box-label" title="Add Postfix. Forcefully inserts extra text to before the AI response to steer the AI.">Add Postfix</div>
<div style="display:flex" id="oaijailbreakpromptblock1"> <div style="display:flex" id="oaijailbreakpromptblock1">
<select class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="jailbreakprompttextrole"> <select title="Injected Prefix Message Role" class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="jailbreakprompttextrole">
<option value="0">User</option> <option value="0">User</option>
<option value="1">Assistant</option> <option value="1">Assistant</option>
<option value="2" selected>System</option> <option value="2" selected>System</option>
</select> </select>
<textarea class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="jailbreakprompttext" placeholder="(Enter System Prefix)" <textarea title="Enter Prefix Prompt String" class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="jailbreakprompttext" placeholder="(Enter System Prefix)"
value="" onload="togglejailbreak();"></textarea> value="" onload="togglejailbreak();"></textarea>
</div> </div>
<div style="display:flex" id="oaijailbreakpromptblock2"> <div style="display:flex" id="oaijailbreakpromptblock2">
<select class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="jailbreakprompttext2role"> <select title="Injected Postfix Message Role" class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="jailbreakprompttext2role">
<option value="0">User</option> <option value="0">User</option>
<option value="1" selected>Assistant</option> <option value="1" selected>Assistant</option>
<option value="2">System</option> <option value="2">System</option>
</select> </select>
<textarea class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%;" type="text" id="jailbreakprompttext2" placeholder="(Enter Assistant Postfix)" <textarea title="Enter Postfix Prompt String" class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%;" type="text" id="jailbreakprompttext2" placeholder="(Enter Assistant Postfix)"
value="" onload="togglejailbreak2();"></textarea> value="" onload="togglejailbreak2();"></textarea>
</div> </div>
</span> </span>
<span id="openrouterproviderbox" class="hidden"><br>Preferred Provider: <input style="height: 25px; font-size:12px;padding:4px;display:inline;width:calc(100% - 140px)" class="form-control" type="text" id="openrouterproviders" placeholder="(Automatic)" value=""> <span id="openrouterproviderbox" class="hidden"><br>Preferred Provider: <input title="Enter Preferred AI Provider" style="height: 25px; font-size:12px;padding:4px;display:inline;width:calc(100% - 140px)" class="form-control" type="text" id="openrouterproviders" placeholder="(Automatic)" value="">
<div style="display:inline;width:210px;"> <div style="display:inline;width:210px;">
</div> </div>
</span> </span>
@ -22643,7 +22662,7 @@ Current version indicated by LITEVER below.
<input class="form-control" type="text" id="custom_claude_endpoint" placeholder="Claude API URL" value=""> <input class="form-control" type="text" id="custom_claude_endpoint" placeholder="Claude API URL" value="">
<input class="form-control" type="password" id="custom_claude_key" placeholder="Claude API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br> <input class="form-control" type="password" id="custom_claude_key" placeholder="Claude API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
Model Choice:<br> Model Choice:<br>
<select style="padding:4px; width:calc(100% - 110px); display:inline-block" class="form-control" id="custom_claude_model" onload="toggleclaudemodel()" onchange="toggleclaudemodel()"> <select title="Claude AI Model Selection" style="padding:4px; width:calc(100% - 110px); display:inline-block" class="form-control" id="custom_claude_model" onload="toggleclaudemodel()" onchange="toggleclaudemodel()">
<option value="claude-v1">claude-v1</option> <option value="claude-v1">claude-v1</option>
<option value="claude-v1-100k">claude-v1-100k</option> <option value="claude-v1-100k">claude-v1-100k</option>
<option value="claude-instant-v1">claude-instant-v1</option> <option value="claude-instant-v1">claude-instant-v1</option>
@ -22661,21 +22680,21 @@ Current version indicated by LITEVER below.
<option value="claude-3-7-sonnet-20250219">claude-3-7-sonnet-20250219</option> <option value="claude-3-7-sonnet-20250219">claude-3-7-sonnet-20250219</option>
</select> </select>
<button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="claudefetchlist" onclick="claude_fetch_models()">Fetch List</button> <button type="button" class="btn btn-primary" style="display:inline;width:105px;" id="claudefetchlist" onclick="claude_fetch_models()">Fetch List</button>
<input type="checkbox" id="claudeaddversion" onchange="" checked> <input type="checkbox" title="Add endpoint version" id="claudeaddversion" onchange="" checked>
<div class="box-label" title="Add endpoint version">Add Endpoint Version</div> <div class="box-label" title="Add endpoint version">Add Endpoint Version</div>
<span id="clauderenamecompatdiv"> <span id="clauderenamecompatdiv">
<input type="checkbox" id="clauderenamecompat" onchange="" checked> <input type="checkbox" title="Claude Compatibility Rename Fix" id="clauderenamecompat" onchange="" checked>
<div class="box-label" title="Rename User and Bot tags to work with claude, force inject them otherwise">Claude Compatibility Rename Fix</div> <div class="box-label" title="Rename User and Bot tags to work with claude, force inject them otherwise">Claude Compatibility Rename Fix</div>
</span> </span>
<textarea class="form-control hidden" rows="2" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="claudesystemprompt" placeholder="(Enter System Prompt)" <textarea class="form-control hidden" title="Claude System Prompt" rows="2" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="claudesystemprompt" placeholder="(Enter System Prompt, which steers overall AI behavior.)"
value="" onload=""></textarea> value="" onload=""></textarea>
<textarea class="form-control hidden" rows="2" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="claudejailbreakprompt" placeholder="(Enter Assistant Postfix)" <textarea class="form-control hidden" title="Claude Assistant Postfix" rows="2" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="claudejailbreakprompt" placeholder="(Enter Assistant Postfix, which forces the AI to start responses with this text.)"
value="" onload=""></textarea> value="" onload=""></textarea>
<div id="claudethinkingbox" class="hidden"> <div id="claudethinkingbox" class="hidden">
<div class="box-label" title="Enable Thinking">Enable Thinking </div> <div class="box-label" title="Enable Thinking">Enable Thinking </div>
<input type="checkbox" style="display:inline;" id="claudethinking"> <input type="checkbox" title="Enable Thinking" style="display:inline;" id="claudethinking">
</div> </div>
</div> </div>
@ -22683,7 +22702,7 @@ Current version indicated by LITEVER below.
Uses Gemini by Google.<br><br> Uses Gemini by Google.<br><br>
Note that KoboldAI Lite takes no responsibility for your usage or consequences of this feature. Your API key is used directly with the Gemini API and is not transmitted to us.<br><br> Note that KoboldAI Lite takes no responsibility for your usage or consequences of this feature. Your API key is used directly with the Gemini API and is not transmitted to us.<br><br>
<div> <div>
<select style="padding:4px; width:calc(100% - 110px); display:inline-block" class="form-control" id="custom_gemini_model" onchange="togglegeminimodel()"> <select title="Gemini AI Model Selection" style="padding:4px; width:calc(100% - 110px); display:inline-block" class="form-control" id="custom_gemini_model" onchange="togglegeminimodel()">
<option value="gemini-1.5-flash-latest" selected="selected">gemini-1.5-flash-latest</option> <option value="gemini-1.5-flash-latest" selected="selected">gemini-1.5-flash-latest</option>
<option value="gemini-1.5-pro-001">gemini-1.5-pro-001</option> <option value="gemini-1.5-pro-001">gemini-1.5-pro-001</option>
<option value="gemini-1.5-pro-002">gemini-1.5-pro-002</option> <option value="gemini-1.5-pro-002">gemini-1.5-pro-002</option>
@ -22700,28 +22719,30 @@ Current version indicated by LITEVER below.
<div id="gemini_role_options"> <div id="gemini_role_options">
<div> <div>
Main Message Role: Main Message Role:
<select class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" onload="togglegeminirole();" onchange="togglegeminirole();" id="geminiroledropdown"> <select title="Main Message Role" class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" onload="togglegeminirole();" onchange="togglegeminirole();" id="geminiroledropdown">
<option value="" selected>Default</option> <option value="" selected>Default</option>
<option value="user">User</option> <option value="user">User</option>
<option value="model">Model</option> <option value="model">Model</option>
</select> </select>
</div> </div>
<div id="gemini_role_options2" style="display:flex"> <div id="gemini_role_options2" style="display:flex">
<select class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="gemini_postfix_role"> <select title="Postfix Message Role" class="form-control" style="height: 25px; font-size:12px; padding:4px;display:inline;width:100px" id="gemini_postfix_role">
<option value="user">User</option> <option value="user">User</option>
<option value="model" selected>Model</option> <option value="model" selected>Model</option>
</select> </select>
<textarea class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%;" type="text" id="gemini_postfix_text" placeholder="(Enter Gemini Postfix)" <textarea title="Gemini Postfix Prompt" class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%;" type="text" id="gemini_postfix_text" placeholder="(Enter Gemini Postfix)"
value=""></textarea> value=""></textarea>
</div> </div>
</div> </div>
<textarea class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="gemini_system_instruction" placeholder="(Enter System Instruction)" <textarea title="Gemini System Prompt" class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="gemini_system_instruction" placeholder="(Enter System Instruction)"
value=""></textarea><br> value=""></textarea><br>
<input type="checkbox" title="Use Gemini WebSearch" id="usegeminiweb">
<div class="box-label">Use WebSearch</div><br>
</div> </div>
<div id="coherecustom" class="menutext hidden"> <div id="coherecustom" class="menutext hidden">
Uses Cohere's models through their own API.<br><br> Uses Cohere's models through their own API.<br><br>
Note that KoboldAI Lite takes no responsibility for your usage or consequences of this feature. Your API key is used directly with the Cohere API and is not transmitted to us.<br><br> Note that KoboldAI Lite takes no responsibility for your usage or consequences of this feature. Your API key is used directly with the Cohere API and is not transmitted to us.<br><br>
<select style="padding:4px;" class="form-control" id="custom_cohere_model"> <select title="Cohere AI Model Selection" style="padding:4px;" class="form-control" id="custom_cohere_model">
<option value="command" selected="selected">command</option> <option value="command" selected="selected">command</option>
<option value="command-r">command-r</option> <option value="command-r">command-r</option>
<option value="command-r-plus">command-r-plus</option> <option value="command-r-plus">command-r-plus</option>
@ -22732,10 +22753,10 @@ Current version indicated by LITEVER below.
</select> </select>
<span class="color_green" style="font-weight: bold;">Please input Cohere API Key.</span><br><br> <span class="color_green" style="font-weight: bold;">Please input Cohere API Key.</span><br><br>
<input class="form-control" type="password" id="custom_cohere_key" placeholder="Cohere API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br> <input class="form-control" type="password" id="custom_cohere_key" placeholder="Cohere API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
<input type="checkbox" id="usecohereweb"> <input type="checkbox" title="Use Cohere WebSearch" id="usecohereweb">
<div class="box-label" id="usecohereweblabel">Use WebSearch</div> <div class="box-label">Use WebSearch</div>
<input type="checkbox" id="useocoherepreamble" onchange="togglecoherepreamble()"> <input type="checkbox" title="Use Cohere Preamble" id="useocoherepreamble" onchange="togglecoherepreamble()">
<div class="box-label" id="useocoherepreamblelabel">Use Preamble</div> <div class="box-label">Use Preamble</div>
<span id="useocoherepreamblebox" class="hidden" onload="togglecoherepreamble();"> <span id="useocoherepreamblebox" class="hidden" onload="togglecoherepreamble();">
<textarea class="form-control" id="cohere_preamble" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" placeholder="(Enter Preamble)" value=""></textarea> <textarea class="form-control" id="cohere_preamble" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" placeholder="(Enter Preamble)" value=""></textarea>

@ -75,6 +75,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" }, { LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" }, { LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" }, { LLM_KV_GENERAL_VERSION, "general.version" },

@ -79,6 +79,7 @@ enum llm_kv {
LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_ARCHITECTURE,
LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT, LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_NAME, LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR, LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION, LLM_KV_GENERAL_VERSION,

@ -255,7 +255,8 @@ llama_context::llama_context(
model.n_devices() > 1 && model.n_devices() > 1 &&
model.params.n_gpu_layers > (int) model.hparams.n_layer && model.params.n_gpu_layers > (int) model.hparams.n_layer &&
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
cparams.offload_kqv; cparams.offload_kqv &&
!model.has_tensor_overrides();
// pipeline parallelism requires support for async compute and events in all devices // pipeline parallelism requires support for async compute and events in all devices
if (pipeline_parallel) { if (pipeline_parallel) {
@ -1202,33 +1203,7 @@ int llama_context::decode(llama_batch & inp_batch) {
const int64_t n_tokens_all = batch.n_tokens; const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
// TODO: remove this stuff llama_kv_cache_guard kv_guard(kv_self.get());
class batch_guard {
public:
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
}
~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}
void done() {
is_done = true;
}
void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}
private:
bool is_done = false;
llama_kv_slot_restorer kv_slot_restorer;
};
batch_guard bg(*kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@ -1281,6 +1256,9 @@ int llama_context::decode(llama_batch & inp_batch) {
return -2; return -2;
}; };
// handle any pending defrags/shifts
kv_self_update();
int64_t n_outputs_prev = 0; int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) { while (sbatch.n_tokens > 0) {
@ -1320,22 +1298,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot // find KV slot
{ {
kv_self_update(); if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
// if we have enough unused cells before the current head -> return 1;
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
kv_self->head = 0;
} }
const auto slot_info = kv_self->find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}
bg.save(slot_info);
if (!kv_self->recurrent) { if (!kv_self->recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears // after enough generations, the benefit from this heuristic disappears
@ -1372,16 +1340,6 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
} }
// update the kv ring buffer
{
kv_self->head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self->head >= kv_self->size) {
kv_self->head = 0;
}
}
// plot the computation graph in dot format (for debugging purposes) // plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) { //if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot"); // ggml_graph_dump_dot(gf, NULL, "llama.dot");
@ -1468,7 +1426,7 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
// finalize the batch processing // finalize the batch processing
bg.done(); kv_guard.commit();
// set output mappings // set output mappings
{ {

@ -11,8 +11,6 @@
#include <map> #include <map>
#include <stdexcept> #include <stdexcept>
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
} }
@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
return false; return false;
} }
} }
return true;
} }
for (uint32_t i = 0; i < size; ++i) { for (uint32_t i = 0; i < size; ++i) {
@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
} }
} }
void llama_kv_cache_unified::restore() {
if (pending.ranges.empty()) {
return;
}
// TODO: tmp - move to llama_kv_cache_recurrent
if (recurrent) {
seq_rm(-1, -1, -1);
return;
}
uint32_t new_head = size;
for (auto & range : pending.ranges) {
for (uint32_t i = range.c0; i < range.c1; ++i) {
cells[i].seq_id.clear();
// keep count of the number of used cells
if (cells[i].pos >= 0) {
used--;
}
cells[i].pos = -1;
cells[i].src = -1;
}
new_head = std::min(new_head, range.c0);
}
if (new_head != size && new_head < head) {
head = new_head;
}
}
void llama_kv_cache_unified::commit() {
if (pending.ranges.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
return;
}
pending.ranges.clear();
}
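Together, find_slot/commit/restore implement a small transaction: each successful find_slot records the cell range it claimed in pending.ranges, commit() forgets those ranges on success, and restore() walks them to free the cells again on failure. A minimal toy sketch of that contract (assumption-level code, far simpler than the real cache):

#include <cstdint>
#include <vector>

struct toy_kv_cache {
    struct slot_range { uint32_t c0, c1; };   // cell indices, [c0, c1)

    std::vector<int32_t>    pos;              // -1 marks a free cell
    std::vector<slot_range> pending;          // ranges written since the last commit

    explicit toy_kv_cache(uint32_t n_cells) : pos(n_cells, -1) {}

    bool find_slot(uint32_t head, uint32_t n_tokens, int32_t p0) {
        if (head + n_tokens > pos.size()) {
            return false;
        }
        for (uint32_t i = 0; i < n_tokens; ++i) {
            pos[head + i] = p0 + int32_t(i);
        }
        pending.push_back({head, head + n_tokens});
        return true;
    }

    void commit() { pending.clear(); }        // batch succeeded: keep the cells

    void restore() {                          // batch failed: roll the cells back
        for (const auto & r : pending) {
            for (uint32_t i = r.c0; i < r.c1; ++i) {
                pos[i] = -1;
            }
        }
        pending.clear();
    }
};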
bool llama_kv_cache_unified::get_can_shift() const { bool llama_kv_cache_unified::get_can_shift() const {
return can_shift; return can_shift;
} }
llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( bool llama_kv_cache_unified::find_slot(
const llama_ubatch & ubatch) { const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens; const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*ubatch.n_tokens) {
head = 0;
}
if (recurrent) { if (recurrent) {
// For recurrent state architectures (like Mamba or RWKV), // For recurrent state architectures (like Mamba or RWKV),
// each cache cell can store the state for a whole sequence. // each cache cell can store the state for a whole sequence.
@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
// too big seq_id // too big seq_id
// TODO: would it be possible to resize the cache instead? // TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
return llama_kv_cache_slot_info_failed; return false;
} }
if (j > 0) { if (j > 0) {
llama_kv_cell & seq = cells[seq_id]; llama_kv_cell & seq = cells[seq_id];
@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
[](const llama_kv_cell& cell){ return !cell.is_empty(); }); [](const llama_kv_cell& cell){ return !cell.is_empty(); });
// sanity check // sanity check
return llama_kv_cache_slot_info(n >= n_seqs); return n >= n_seqs;
} }
// otherwise, one cell per token. // otherwise, one cell per token.
if (n_tokens > size) { if (n_tokens > size) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
return llama_kv_cache_slot_info_failed; return false;
} }
uint32_t n_tested = 0; uint32_t n_tested = 0;
@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
if (n_tested >= size) { if (n_tested >= size) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return llama_kv_cache_slot_info_failed; return false;
} }
} }
@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
used += n_tokens; used += n_tokens;
return llama_kv_cache_slot_info(head, head + n_tokens); pending.ranges.push_back({head, head + n_tokens});
return true;
} }
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false; return false;
} }
commit();
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells // Assume that this is one contiguous block of cells

View file

@ -17,6 +17,9 @@ struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i { struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i; using llama_memory_i::llama_memory_i;
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_n_tokens() const = 0; virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
@ -25,6 +28,21 @@ struct llama_kv_cache : public llama_memory_i {
bool get_can_edit() const override { return get_can_shift(); } bool get_can_edit() const override { return get_can_shift(); }
}; };
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
~llama_kv_cache_guard() {
kv->restore();
}
void commit() {
kv->commit();
}
private:
llama_kv_cache * kv;
};
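The guard is a plain RAII wrapper: its destructor always calls restore(), so any early return between find_slot and the end of batch processing rolls the cache back, while an explicit commit() clears the pending ranges first and turns the destructor's restore() into a no-op. A hedged sketch of the intended call pattern, mirroring llama_context::decode above (the internal header name is an assumption):

#include "llama-kv-cache.h"   // assumed internal header providing llama_kv_cache_guard

static int decode_like(llama_kv_cache & kv_self) {
    llama_kv_cache_guard kv_guard(&kv_self);    // destructor calls kv_self.restore()

    // per ubatch:
    //   if (!find_slot(ubatch)) { return 1; }  // guard rolls back earlier ubatches
    //   ... build and compute the graph ...

    kv_guard.commit();                          // success: pending ranges cleared, cells kept
    return 0;
}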
struct llama_kv_cell { struct llama_kv_cell {
llama_pos pos = -1; llama_pos pos = -1;
llama_pos delta = 0; llama_pos delta = 0;
@ -46,17 +64,6 @@ struct llama_kv_cell {
} }
}; };
// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
bool found = false; // the slot was found
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
operator bool() const { return found; }
};
// ring-buffer of cached KV data // ring-buffer of cached KV data
// TODO: pimpl // TODO: pimpl
// TODO: add notion of max sequences // TODO: add notion of max sequences
@ -93,6 +100,9 @@ public:
void clear() override; void clear() override;
void defrag() override; void defrag() override;
virtual void restore() override;
virtual void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override; void seq_keep(llama_seq_id seq_id) override;
@ -105,10 +115,9 @@ public:
// find an empty slot of size "n_tokens" in the cache // find an empty slot of size "n_tokens" in the cache
// updates the cache head // updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points // Note: On success, it's important that cache.head points
// to the first cell of the slot. // to the first cell of the slot.
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch); bool find_slot(const llama_ubatch & batch);
// TODO: maybe not needed // TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const; uint32_t get_padding(const llama_cparams & cparams) const;
@ -128,7 +137,19 @@ public:
// return true if cells have been moved // return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes); bool defrag_prepare(int32_t n_max_nodes);
// state save/load // commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
@ -183,59 +204,6 @@ private:
// using llama_kv_cache_unified::llama_kv_cache_unified; // using llama_kv_cache_unified::llama_kv_cache_unified;
//}; //};
//
// kv cache restore
//
// saves the kv_cache state for future recovery.
// used to rollback llama_kv_cache_find_slot changes.
struct llama_kv_slot_restorer {
struct llama_kv_cache_state {
uint32_t head = 0;
uint32_t n = 0;
} old_state;
// for non-recurrent models only
// list of slots to restore
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
bool do_restore = false;
llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
// saves a slot information for future restoration
void save(const llama_kv_cache_slot_info & slot) {
if (slot) {
do_restore = true;
if (slot.boundaries.first != slot.boundaries.second) {
slot_boundaries.push_back(slot.boundaries);
}
}
}
// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
cache.seq_rm(-1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {
cache.seq_rm(-1, slot.first, slot.second);
}
}
}
}
};
// TODO: maybe become part of the public llama_kv_cache in the future // TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv); int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);

@@ -449,7 +449,8 @@ llama_model_loader::llama_model_loader(
        std::vector<std::string> & splits,
        bool use_mmap,
        bool check_tensors,
-       const struct llama_model_kv_override * param_overrides_p) {
+       const llama_model_kv_override * param_overrides_p,
+       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
    int trace = 0;
    if (getenv("LLAMA_TRACE")) {
        trace = atoi(getenv("LLAMA_TRACE"));
@@ -461,6 +462,8 @@ llama_model_loader::llama_model_loader(
        }
    }

+   tensor_buft_overrides = param_tensor_buft_overrides_p;
+
    // Load the main GGUF
    struct ggml_context * ctx = NULL;
    struct gguf_init_params params = {
@@ -605,7 +608,9 @@ llama_model_loader::llama_model_loader(
            if (trace > 0) {
                const uint16_t sid = w.idx;
-               LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
+               LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ] %8.2f MiB\n", __func__,
+                       sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str(),
+                       ggml_nbytes(tensor)/1024.0f/1024.0f);
            }
        }
@@ -645,9 +650,9 @@ llama_model_loader::llama_model_loader(
        ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);

    {
-       const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV
-       if (kid >= 0) {
-           ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid);
+       uint32_t ftype_val = 0;
+       if (get_key(LLM_KV_GENERAL_FILE_TYPE, ftype_val, false)) {
+           ftype = (llama_ftype) ftype_val;
        }
    }
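Note: the last hunk above swaps a raw gguf_find_key / gguf_get_val_u32 lookup for the loader's get_key helper. For readers working outside the loader, here is a minimal standalone sketch of the raw approach through the public gguf API (metadata-only load; the model path is a placeholder, and the exact signatures should be checked against gguf.h).

#include <cstdint>
#include <cstdio>

#include "gguf.h"

int main(int argc, char ** argv) {
    const char * fname = argc > 1 ? argv[1] : "model.gguf"; // placeholder path

    struct gguf_init_params params = {
        /*.no_alloc =*/ true,    // read metadata only, do not allocate tensor data
        /*.ctx      =*/ nullptr,
    };

    struct gguf_context * ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        std::fprintf(stderr, "failed to load %s\n", fname);
        return 1;
    }

    const int64_t kid = gguf_find_key(ctx, "general.file_type");
    if (kid >= 0) {
        std::printf("general.file_type = %u\n", gguf_get_val_u32(ctx, kid));
    } else {
        std::printf("general.file_type is not set\n");
    }

    gguf_free(ctx);
    return 0;
}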

View file

@@ -77,8 +77,9 @@ struct llama_model_loader {
    llama_mmaps mappings;

-   std::map<std::string, struct llama_tensor_weight, weight_name_comparer> weights_map;
-   std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
+   std::map<std::string, llama_tensor_weight, weight_name_comparer> weights_map;
+   std::unordered_map<std::string, llama_model_kv_override> kv_overrides;
+   const llama_model_tensor_buft_override * tensor_buft_overrides;

    gguf_context_ptr meta;
    std::vector<ggml_context_ptr> contexts;
@@ -95,7 +96,8 @@ struct llama_model_loader {
        std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
        bool use_mmap,
        bool check_tensors,
-       const struct llama_model_kv_override * param_overrides_p);
+       const llama_model_kv_override * param_overrides_p,
+       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);

    template<typename T>
    typename std::enable_if<std::is_integral<T>::value, bool>::type
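Note: the trailing declaration above is the start of the loader's enable_if-gated get_key overload set. As a rough, generic illustration of that SFINAE pattern only (mock storage and names, not the real llama_model_loader API):

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <type_traits>

struct mock_kv_store {
    std::map<std::string, long long>   ints; // hypothetical backing storage
    std::map<std::string, std::string> strs;

    // participates in overload resolution only for integral T
    template <typename T>
    typename std::enable_if<std::is_integral<T>::value, bool>::type
    get_key(const std::string & key, T & out) const {
        auto it = ints.find(key);
        if (it == ints.end()) return false;
        out = (T) it->second;
        return true;
    }

    // participates only when T is std::string
    template <typename T>
    typename std::enable_if<std::is_same<T, std::string>::value, bool>::type
    get_key(const std::string & key, T & out) const {
        auto it = strs.find(key);
        if (it == strs.end()) return false;
        out = it->second;
        return true;
    }
};

int main() {
    mock_kv_store kv;
    kv.ints["general.file_type"] = 7;

    uint32_t ftype = 0;
    if (kv.get_key("general.file_type", ftype)) { // integral overload selected
        std::printf("file_type = %u\n", ftype);
    }
    return 0;
}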

View file

@@ -17,6 +17,7 @@
#include <cmath>
#include <functional>
#include <map>
+#include <regex>
#include <sstream>
#include <stdexcept>
#include <iostream>
@@ -383,9 +384,12 @@ struct llama_model::impl {
    layer_dev dev_input = {};
    layer_dev dev_output = {};
    std::vector<layer_dev> dev_layer;
+
+   bool has_tensor_overrides;
};

llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+   pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
}

llama_model::~llama_model() {}
@@ -1586,10 +1590,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
            }

-           ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+           ggml_backend_buffer_type_t buft = nullptr;
+
+           // check overrides
+           if (ml.tensor_buft_overrides) {
+               std::string tensor_name = tn.str();
+               for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                   std::regex pattern(overrides->pattern);
+                   if (std::regex_search(tensor_name, pattern)) {
+                       LLAMA_LOG_DEBUG("tensor %s buffer type overridden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                       buft = overrides->buft;
+                       break;
+                   }
+               }
+           }
+
+           if (!buft) {
+               buft = select_weight_buft(hparams, t_meta, op, *buft_list);
                if (!buft) {
                    throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
                }
+           }

            // avoid using a host buffer when using mmap
            auto * buft_dev = ggml_backend_buft_get_device(buft);
@@ -4250,6 +4271,10 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
        });
}

+bool llama_model::has_tensor_overrides() const {
+   return pimpl->has_tensor_overrides;
+}
+
const ggml_tensor * llama_model::get_tensor(const char * name) const {
    auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
        [name](const std::pair<std::string, ggml_tensor *> & it) {
@@ -12422,6 +12447,7 @@ llm_graph_result_ptr llama_model::build_graph(
llama_model_params llama_model_default_params() {
    llama_model_params result = {
        /*.devices =*/ nullptr,
+       /*.tensor_buft_overrides =*/ nullptr,
        /*.n_gpu_layers =*/ 0,
        /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
        /*.main_gpu =*/ 0,
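Note on the caller side: the loader walks a pattern/buft array terminated by a null pattern, so an application fills llama_model_params.tensor_buft_overrides with such an array. A hedged sketch follows; the regex, the choice of the CPU buffer type, the helper name, and the ggml-cpu.h include location are illustrative assumptions, not a prescribed recipe.

#include "llama.h"
#include "ggml-cpu.h" // assumed location of ggml_backend_cpu_buffer_type()

// Illustrative only: keep tensors whose names match the pattern in host memory.
static llama_model * load_with_cpu_overrides(const char * path) {
    // static so the pointer stored in llama_model_params stays valid for the model's lifetime
    static llama_model_tensor_buft_override overrides[2] = {};
    overrides[0].pattern = "ffn_(up|down|gate)_exps";      // matched with std::regex_search against tensor names
    overrides[0].buft    = ggml_backend_cpu_buffer_type();
    // overrides[1] stays zero-initialized: pattern == nullptr terminates the list

    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = overrides;

    return llama_model_load_from_file(path, mparams);
}

The loader-side regex matching in the load_tensors hunk above then decides per tensor whether an override applies.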

View file

@@ -382,6 +382,8 @@ struct llama_model {
    ggml_backend_buffer_type_t select_buft(int il) const;

+   bool has_tensor_overrides() const;
+
    const struct ggml_tensor * get_tensor(const char * name) const;

    // TODO: move this to new llm_arch_model_i interface

View file

@@ -530,7 +530,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
    }

    std::vector<std::string> splits = {};
-   llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
+   llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides, nullptr);
    ml.init_mappings(false); // no prefetching

    llama_model model(llama_model_default_params());

View file

@@ -636,7 +636,8 @@ struct llm_tokenizer_bpe : llm_tokenizer {
            regex_exprs = {
                // original regex from tokenizer.json
                // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+"
-               "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
+               // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?)
+               "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
            };
            break;
        default:
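Note on the FIXME above: one plausible reason the possessive quantifiers caused errors (an assumption about where the pattern ends up) is that patterns which fall through to std::regex hit the ECMAScript grammar, which has no possessive quantifiers, so construction fails, while the greedy form compiles. A tiny standalone check:

#include <cstdio>
#include <regex>
#include <string>

static void try_compile(const char * label, const std::string & pattern) {
    try {
        std::regex re(pattern);   // throws std::regex_error on syntax it does not accept
        std::printf("%-28s compiled OK\n", label);
    } catch (const std::regex_error & e) {
        std::printf("%-28s rejected: %s\n", label, e.what());
    }
}

int main() {
    try_compile("possessive \"[0-9]++\"", "[0-9]++"); // typically rejected by ECMAScript-grammar engines
    try_compile("greedy     \"[0-9]+\"",  "[0-9]+");  // accepted
    return 0;
}

The trade-off the comment hedges about is real: dropping possessiveness changes backtracking behavior even when the matched text is the same.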

View file

@@ -120,7 +120,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
    model.t_start_us = tm.t_start_us;

    try {
-       llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+       llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides, params.tensor_buft_overrides);

        ml.print_info();