diff --git a/Makefile b/Makefile index b5b0481ae..47bd6d9fc 100644 --- a/Makefile +++ b/Makefile @@ -764,23 +764,23 @@ clean: rm -vrf llguidance # useful tools -main: tools/main/main.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) +main: tools/main/main.cpp common/arg.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -mainvk: tools/main/main.cpp common/arg.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib +mainvk: tools/main/main.cpp common/arg.cpp common/download.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib $(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS) sdmain: otherarch/sdcpp/util.cpp otherarch/sdcpp/main.cpp otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/tokenize_util.cpp otherarch/sdcpp/thirdparty/zip.c build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -ttsmain: tools/tts/tts.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) +ttsmain: tools/tts/tts.cpp common/arg.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) gguf-split: tools/gguf-split/gguf-split.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o build-info.h llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -mtmd-cli: tools/mtmd/mtmd-cli.cpp tools/mtmd/mtmd.cpp tools/mtmd/mtmd-helper.cpp tools/mtmd/clip.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) +mtmd-cli: tools/mtmd/mtmd-cli.cpp tools/mtmd/mtmd.cpp tools/mtmd/mtmd-helper.cpp tools/mtmd/clip.cpp common/arg.cpp common/download.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -embedding: examples/embedding/embedding.cpp common/arg.cpp src/llama-cparams.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) +embedding: examples/embedding/embedding.cpp common/arg.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -embeddingvk: examples/embedding/embedding.cpp common/arg.cpp src/llama-cparams.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib +embeddingvk: examples/embedding/embedding.cpp common/arg.cpp common/download.cpp src/llama-cparams.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o ggml-vulkan-shaders.o ggml-repack.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib $(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS) ttscppmain: otherarch/ttscpp/cli/cli.cpp otherarch/ttscpp/cli/playback.cpp otherarch/ttscpp/cli/playback.h otherarch/ttscpp/cli/write_file.cpp otherarch/ttscpp/cli/write_file.h otherarch/ttscpp/cli/vad.cpp otherarch/ttscpp/cli/vad.h otherarch/ttscpp/src/ttscpp.cpp otherarch/ttscpp/src/ttstokenizer.cpp otherarch/ttscpp/src/ttssampler.cpp otherarch/ttscpp/src/parler_model.cpp otherarch/ttscpp/src/dac_model.cpp otherarch/ttscpp/src/ttsutil.cpp otherarch/ttscpp/src/ttsargs.cpp otherarch/ttscpp/src/ttst5_encoder_model.cpp otherarch/ttscpp/src/phonemizer.cpp otherarch/ttscpp/src/tts_model.cpp otherarch/ttscpp/src/kokoro_model.cpp otherarch/ttscpp/src/dia_model.cpp otherarch/ttscpp/src/orpheus_model.cpp otherarch/ttscpp/src/snac_model.cpp otherarch/ttscpp/src/general_neural_audio_codec.cpp ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o ggml-repack.o $(OBJS_FULL) $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/common/arg.cpp b/common/arg.cpp index af722557f..a340558b9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2,12 +2,12 @@ #include "chat.h" #include "common.h" -#include "gguf.h" // for reading GGUF splits #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" #include "chat.h" #include "build-info.h" +#include "download.h" // fix problem with std::min and std::max #if defined(_WIN32) @@ -24,23 +24,14 @@ #include #include #include -#include #include -#include #include #include #include #include -#include +#include // for hardware_concurrency #include -#if defined(LLAMA_USE_CURL) -#include -#include -#elif defined(LLAMA_USE_HTTPLIB) -#include "http.h" -#endif - #ifdef __linux__ #include #elif defined(_WIN32) @@ -54,16 +45,9 @@ #endif #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 -// isatty -#if defined(_WIN32) -#include -#else -#include -#endif - using json = nlohmann::ordered_json; -std::initializer_list mmproj_examples = { +static std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, }; @@ -78,50 +62,13 @@ static std::string read_file(const std::string & fname) { return content; } -static void write_file(const std::string & fname, const std::string & content) { - const std::string fname_tmp = fname + ".tmp"; - std::ofstream file(fname_tmp); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); - } - - try { - file << content; - file.close(); - - // Makes write atomic - if (rename(fname_tmp.c_str(), fname.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str()); - // If rename fails, try to delete the temporary file - if (remove(fname_tmp.c_str()) != 0) { - LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str()); - } - } - } catch (...) { - // If anything fails, try to delete the temporary file - if (remove(fname_tmp.c_str()) != 0) { - LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str()); - } - - throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str())); - } -} - -static bool is_output_a_tty() { -#if defined(_WIN32) - return _isatty(_fileno(stdout)); -#else - return isatty(1); -#endif -} - common_arg & common_arg::set_examples(std::initializer_list examples) { - this->examples = std::move(examples); + this->examples = examples; return *this; } common_arg & common_arg::set_excludes(std::initializer_list excludes) { - this->excludes = std::move(excludes); + this->excludes = excludes; return *this; } @@ -144,7 +91,7 @@ bool common_arg::is_exclude(enum llama_example ex) { return excludes.find(ex) != excludes.end(); } -bool common_arg::get_value_from_env(std::string & output) { +bool common_arg::get_value_from_env(std::string & output) const { if (env == nullptr) return false; char * value = std::getenv(env); if (value) { @@ -154,7 +101,7 @@ bool common_arg::get_value_from_env(std::string & output) { return false; } -bool common_arg::has_value_from_env() { +bool common_arg::has_value_from_env() const { return env != nullptr && std::getenv(env); } @@ -222,964 +169,6 @@ std::string common_arg::to_string() { return ss.str(); } -// -// downloader -// - -struct common_hf_file_res { - std::string repo; // repo name with ":tag" removed - std::string ggufFile; - std::string mmprojFile; -}; - -static void write_etag(const std::string & path, const std::string & etag) { - const std::string etag_path = path + ".etag"; - write_file(etag_path, etag); - LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str()); -} - -static std::string read_etag(const std::string & path) { - std::string none; - const std::string etag_path = path + ".etag"; - - if (std::filesystem::exists(etag_path)) { - std::ifstream etag_in(etag_path); - if (!etag_in) { - LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); - return none; - } - std::string etag; - std::getline(etag_in, etag); - return etag; - } - - // no etag file, but maybe there is an old .json - // remove this code later - const std::string metadata_path = path + ".json"; - - if (std::filesystem::exists(metadata_path)) { - std::ifstream metadata_in(metadata_path); - try { - nlohmann::json metadata_json; - metadata_in >> metadata_json; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata_json.dump().c_str()); - if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { - std::string etag = metadata_json.at("etag"); - write_etag(path, etag); - if (!std::filesystem::remove(metadata_path)) { - LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); - } - return etag; - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } - } - return none; -} - -#ifdef LLAMA_USE_CURL - -// -// CURL utils -// - -using curl_ptr = std::unique_ptr; - -// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one -struct curl_slist_ptr { - struct curl_slist * ptr = nullptr; - ~curl_slist_ptr() { - if (ptr) { - curl_slist_free_all(ptr); - } - } -}; - -static CURLcode common_curl_perf(CURL * curl) { - CURLcode res = curl_easy_perform(curl); - if (res != CURLE_OK) { - LOG_ERR("%s: curl_easy_perform() failed\n", __func__); - } - - return res; -} - -// Send a HEAD request to retrieve the etag and last-modified headers -struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; - std::string accept_ranges; -}; - -struct FILE_deleter { - void operator()(FILE * f) const { fclose(f); } -}; - -static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase); - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } else if (std::regex_match(key, match, accept_ranges_regex)) { - headers->accept_ranges = value; - } - } - - return n_items; -} - -static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) { - return std::fwrite(data, size, nmemb, static_cast(fd)); -} - -// helper function to hide password in URL -static std::string llama_download_hide_password_in_url(const std::string & url) { - // Use regex to match and replace the user[:password]@ pattern in URLs - // Pattern: scheme://[user[:password]@]host[...] - static const std::regex url_regex(R"(^(?:[A-Za-z][A-Za-z0-9+.-]://)(?:[^/@]+@)?.$)"); - std::smatch match; - - if (std::regex_match(url, match, url_regex)) { - // match[1] = scheme (e.g., "https://") - // match[2] = user[:password]@ part - // match[3] = rest of URL (host and path) - return match[1].str() + "********@" + match[3].str(); - } - - return url; // No credentials found or malformed URL -} - -static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) { - // Set the URL, allow to follow http redirection - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - -# if defined(_WIN32) - // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of - // operating system. Currently implemented under MS-Windows. - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -# endif - - curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback); -} - -static void common_curl_easy_setopt_get(CURL * curl) { - curl_easy_setopt(curl, CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback); - - // display download progress - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); -} - -static bool common_pull_file(CURL * curl, const std::string & path_temporary) { - if (std::filesystem::exists(path_temporary)) { - const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary)); - LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str()); - const std::string range_str = partial_size + "-"; - curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str()); - } - - // Always open file in append mode could be resuming - std::unique_ptr outfile(fopen(path_temporary.c_str(), "ab")); - if (!outfile) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str()); - return false; - } - - common_curl_easy_setopt_get(curl); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get()); - - return common_curl_perf(curl) == CURLE_OK; -} - -static bool common_download_head(CURL * curl, - curl_slist_ptr & http_headers, - const std::string & url, - const std::string & bearer_token) { - if (!curl) { - LOG_ERR("%s: error initializing libcurl\n", __func__); - return false; - } - - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - // Check if hf-token or bearer-token was specified - if (!bearer_token.empty()) { - std::string auth_header = "Authorization: Bearer " + bearer_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr); - common_curl_easy_setopt_head(curl, url); - return common_curl_perf(curl) == CURLE_OK; -} - -// download one single file from remote URL to local path -static bool common_download_file_single_online(const std::string & url, - const std::string & path, - const std::string & bearer_token) { - static const int max_attempts = 3; - static const int retry_delay_seconds = 2; - for (int i = 0; i < max_attempts; ++i) { - std::string etag; - - // Check if the file already exists locally - const auto file_exists = std::filesystem::exists(path); - if (file_exists) { - etag = read_etag(path); - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - bool head_request_ok = false; - bool should_download = !file_exists; // by default, we should download if the file does not exist - - // Initialize libcurl - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - common_load_model_from_url_headers headers; - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - curl_slist_ptr http_headers; - const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); - if (!was_perform_successful) { - head_request_ok = false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code == 200) { - head_request_ok = true; - } else { - LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - head_request_ok = false; - } - - // if head_request_ok is false, we don't have the etag or last-modified headers - // we leave should_download as-is, which is true if the file does not exist - bool should_download_from_scratch = false; - if (head_request_ok) { - // check if ETag or Last-Modified headers are different - // if it is, we need to download the file again - if (!etag.empty() && etag != headers.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), - headers.etag.c_str()); - should_download = true; - should_download_from_scratch = true; - } - } - - const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none"; - if (should_download) { - if (file_exists && - !accept_ranges_supported) { // Resumable downloads not supported, delete and start again. - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - - const std::string path_temporary = path + ".downloadInProgress"; - if (should_download_from_scratch) { - if (std::filesystem::exists(path_temporary)) { - if (remove(path_temporary.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); - return false; - } - } - - if (std::filesystem::exists(path)) { - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - } - if (head_request_ok) { - write_etag(path, headers.etag); - } - - // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", - __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(), - headers.etag.c_str(), headers.last_modified.c_str()); - const bool was_pull_successful = common_pull_file(curl.get(), path_temporary); - if (!was_pull_successful) { - if (i + 1 < max_attempts) { - const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; - LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else { - LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); - } - - continue; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { - LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; - } - - if (rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; - } - } else { - LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - } - - break; - } - - return true; -} - -std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - std::vector res_buffer; - - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); - curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L); - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); - auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - auto data_vec = static_cast *>(data); - data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb); - return size * nmemb; - }; - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer); -#if defined(_WIN32) - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - if (params.timeout > 0) { - curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout); - } - if (params.max_size > 0) { - curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); - } - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - for (const auto & header : params.headers) { - http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); - } - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); - - CURLcode res = curl_easy_perform(curl.get()); - - if (res != CURLE_OK) { - std::string error_msg = curl_easy_strerror(res); - throw std::runtime_error("error: cannot make GET request: " + error_msg); - } - - long res_code; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - - return { res_code, std::move(res_buffer) }; -} - -#else - -#ifdef LLAMA_USE_HTTPLIB -static void print_progress(size_t current, size_t total) { - if (!is_output_a_tty()) { - return; - } - - if (!total) { - return; - } - - size_t width = 50; - size_t pct = (100 * current) / total; - size_t pos = (width * current) / total; - - std::cout << "[" - << std::string(pos, '=') - << (pos < width ? ">" : "") - << std::string(width - pos, ' ') - << "] " << std::setw(3) << pct << "% (" - << current / (1024 * 1024) << " MB / " - << total / (1024 * 1024) << " MB)\r"; - std::cout.flush(); -} - -static bool common_pull_file(httplib::Client & cli, - const std::string & resolve_path, - const std::string & path_tmp, - bool supports_ranges, - size_t existing_size, - size_t & total_size) { - std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app); - if (!ofs.is_open()) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str()); - return false; - } - - httplib::Headers headers; - if (supports_ranges && existing_size > 0) { - headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); - } - - std::atomic downloaded{existing_size}; - - auto res = cli.Get(resolve_path, headers, - [&](const httplib::Response &response) { - if (existing_size > 0 && response.status != 206) { - LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status); - return false; - } - if (existing_size == 0 && response.status != 200) { - LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status); - return false; - } - if (total_size == 0 && response.has_header("Content-Length")) { - try { - size_t content_length = std::stoull(response.get_header_value("Content-Length")); - total_size = existing_size + content_length; - } catch (const std::exception &e) { - LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what()); - } - } - return true; - }, - [&](const char *data, size_t len) { - ofs.write(data, len); - if (!ofs) { - LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str()); - return false; - } - downloaded += len; - print_progress(downloaded, total_size); - return true; - }, - nullptr - ); - - std::cout << "\n"; - - if (!res) { - LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1); - return false; - } - - return true; -} - -// download one single file from remote URL to local path -static bool common_download_file_single_online(const std::string & url, - const std::string & path, - const std::string & bearer_token) { - static const int max_attempts = 3; - static const int retry_delay_seconds = 2; - - auto [cli, parts] = common_http_client(url); - - httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; - if (!bearer_token.empty()) { - default_headers.insert({"Authorization", "Bearer " + bearer_token}); - } - cli.set_default_headers(default_headers); - - const bool file_exists = std::filesystem::exists(path); - - std::string last_etag; - if (file_exists) { - last_etag = read_etag(path); - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - for (int i = 0; i < max_attempts; ++i) { - auto head = cli.Head(parts.path); - bool head_ok = head && head->status >= 200 && head->status < 300; - if (!head_ok) { - LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); - if (file_exists) { - LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); - return true; - } - } - - std::string etag; - if (head_ok && head->has_header("ETag")) { - etag = head->get_header_value("ETag"); - } - - size_t total_size = 0; - if (head_ok && head->has_header("Content-Length")) { - try { - total_size = std::stoull(head->get_header_value("Content-Length")); - } catch (const std::exception& e) { - LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); - } - } - - bool supports_ranges = false; - if (head_ok && head->has_header("Accept-Ranges")) { - supports_ranges = head->get_header_value("Accept-Ranges") != "none"; - } - - bool should_download_from_scratch = false; - if (!last_etag.empty() && !etag.empty() && last_etag != etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, - last_etag.c_str(), etag.c_str()); - should_download_from_scratch = true; - } - - if (file_exists) { - if (!should_download_from_scratch) { - LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - return true; - } - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return false; - } - } - - const std::string path_temporary = path + ".downloadInProgress"; - size_t existing_size = 0; - - if (std::filesystem::exists(path_temporary)) { - if (supports_ranges && !should_download_from_scratch) { - existing_size = std::filesystem::file_size(path_temporary); - } else if (remove(path_temporary.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); - return false; - } - } - - // start the download - LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); - const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); - if (!was_pull_successful) { - if (i + 1 < max_attempts) { - const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; - LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else { - LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); - } - continue; - } - - if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; - } - if (!etag.empty()) { - write_etag(path, etag); - } - break; - } - - return true; -} - -std::pair> common_remote_get_content(const std::string & url, - const common_remote_params & params) { - auto [cli, parts] = common_http_client(url); - - httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; - for (const auto & header : params.headers) { - size_t pos = header.find(':'); - if (pos != std::string::npos) { - headers.emplace(header.substr(0, pos), header.substr(pos + 1)); - } else { - headers.emplace(header, ""); - } - } - - if (params.timeout > 0) { - cli.set_read_timeout(params.timeout, 0); - cli.set_write_timeout(params.timeout, 0); - } - - std::vector buf; - auto res = cli.Get(parts.path, headers, - [&](const char *data, size_t len) { - buf.insert(buf.end(), data, data + len); - return params.max_size == 0 || - buf.size() <= static_cast(params.max_size); - }, - nullptr - ); - - if (!res) { - throw std::runtime_error("error: cannot make GET request"); - } - - return { res->status, std::move(buf) }; -} - -#else //no httplib - -bool common_has_curl() { - return false; -} - -static bool common_download_file_single_online(const std::string &, const std::string &, const std::string &) { - LOG_ERR("error: built without CURL, cannot download model from internet\n"); - return false; -} - -std::pair> common_remote_get_content(const std::string & url, const common_remote_params &) { - if (!url.empty()) { - throw std::runtime_error("error: built without CURL, cannot download model from the internet"); - } - - return {}; -} -#endif - -#endif // LLAMA_USE_CURL - -static bool common_download_file_single(const std::string & url, - const std::string & path, - const std::string & bearer_token, - bool offline) { - if (!offline) { - return common_download_file_single_online(url, path, bearer_token); - } - - if (!std::filesystem::exists(path)) { - LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); - return false; - } - - LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); - return true; -} - -// download multiple files from remote URLs to local paths -// the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { - // Prepare download in parallel - std::vector> futures_download; - for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); - }, item)); - } - - // Wait for all downloads to complete - for (auto & f : futures_download) { - if (!f.get()) { - return false; - } - } - - return true; -} - -static bool common_download_model( - const common_params_model & model, - const std::string & bearer_token, - bool offline) { - // Basic validation of the model.url - if (model.url.empty()) { - LOG_ERR("%s: invalid model url\n", __func__); - return false; - } - - if (!common_download_file_single(model.url, model.path, bearer_token, offline)) { - return false; - } - - // check for additional GGUFs split to download - int n_split = 0; - { - struct gguf_init_params gguf_params = { - /*.no_alloc = */ true, - /*.ctx = */ NULL, - }; - auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); - if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); - return false; - } - - auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); - if (key_n_split >= 0) { - n_split = gguf_get_val_u16(ctx_gguf, key_n_split); - } - - gguf_free(ctx_gguf); - } - - if (n_split > 1) { - char split_prefix[PATH_MAX] = {0}; - char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0}; - - // Verify the first split file format - // and extract split URL and PATH prefixes - { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); - return false; - } - - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); - return false; - } - } - - std::vector> urls; - for (int idx = 1; idx < n_split; idx++) { - char split_path[PATH_MAX] = {0}; - llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); - - char split_url[LLAMA_MAX_URL_LENGTH] = {0}; - llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); - - if (std::string(split_path) == model.path) { - continue; // skip the already downloaded file - } - - urls.push_back({split_url, split_path}); - } - - // Download in parallel - common_download_file_multiple(urls, bearer_token, offline); - } - - return true; -} - -/** - * Allow getting the HF file from the HF repo with tag (like ollama), for example: - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 - * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M - * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s - * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) - * - * Return pair of (with "repo" already having tag removed) - * - * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. - */ -static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { - auto parts = string_split(hf_repo_with_tag, ':'); - std::string tag = parts.size() > 1 ? parts.back() : "latest"; - std::string hf_repo = parts[0]; - if (string_split(hf_repo, '/').size() != 2) { - throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); - } - - std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; - - // headers - std::vector headers; - headers.push_back("Accept: application/json"); - if (!bearer_token.empty()) { - headers.push_back("Authorization: Bearer " + bearer_token); - } - // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - // User-Agent header is already set in common_remote_get_content, no need to set it here - - // we use "=" to avoid clashing with other component, while still being allowed on windows - std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; - string_replace_all(cached_response_fname, "/", "_"); - std::string cached_response_path = fs_get_cache_file(cached_response_fname); - - // make the request - common_remote_params params; - params.headers = headers; - long res_code = 0; - std::string res_str; - bool use_cache = false; - if (!offline) { - try { - auto res = common_remote_get_content(url, params); - res_code = res.first; - res_str = std::string(res.second.data(), res.second.size()); - } catch (const std::exception & e) { - LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); - } - } - if (res_code == 0) { - if (std::filesystem::exists(cached_response_path)) { - LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); - res_str = read_file(cached_response_path); - res_code = 200; - use_cache = true; - } else { - throw std::runtime_error( - offline ? "error: failed to get manifest (offline mode)" - : "error: failed to get manifest (check your internet connection)"); - } - } - std::string ggufFile; - std::string mmprojFile; - - if (res_code == 200 || res_code == 304) { - try { - auto j = json::parse(res_str); - - if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) { - ggufFile = j["ggufFile"]["rfilename"].get(); - } - if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) { - mmprojFile = j["mmprojFile"]["rfilename"].get(); - } - } catch (const std::exception & e) { - throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what()); - } - if (!use_cache) { - // if not using cached response, update the cache file - write_file(cached_response_path, res_str); - } - } else if (res_code == 401) { - throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); - } else { - throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); - } - - // check response - if (ggufFile.empty()) { - throw std::runtime_error("error: model does not have ggufFile"); - } - - return { hf_repo, ggufFile, mmprojFile }; -} - -// -// Docker registry functions -// - -static std::string common_docker_get_token(const std::string & repo) { - std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; - - common_remote_params params; - auto res = common_remote_get_content(url, params); - - if (res.first != 200) { - throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first)); - } - - std::string response_str(res.second.begin(), res.second.end()); - nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str); - - if (!response.contains("token")) { - throw std::runtime_error("Docker registry token response missing 'token' field"); - } - - return response["token"].get(); -} - -static std::string common_docker_resolve_model(const std::string & docker) { - // Parse ai/smollm2:135M-Q4_0 - size_t colon_pos = docker.find(':'); - std::string repo, tag; - if (colon_pos != std::string::npos) { - repo = docker.substr(0, colon_pos); - tag = docker.substr(colon_pos + 1); - } else { - repo = docker; - tag = "latest"; - } - - // ai/ is the default - size_t slash_pos = docker.find('/'); - if (slash_pos == std::string::npos) { - repo.insert(0, "ai/"); - } - - LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str()); - try { - // --- helper: digest validation --- - auto validate_oci_digest = [](const std::string & digest) -> std::string { - // Expected: algo:hex ; start with sha256 (64 hex chars) - // You can extend this map if supporting other algorithms in future. - static const std::regex re("^sha256:([a-fA-F0-9]{64})$"); - std::smatch m; - if (!std::regex_match(digest, m, re)) { - throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest); - } - // normalize hex to lowercase - std::string normalized = digest; - std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){ - return std::tolower(c); - }); - return normalized; - }; - - std::string token = common_docker_get_token(repo); // Get authentication token - - // Get manifest - const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; - std::string manifest_url = url_prefix + "/manifests/" + tag; - common_remote_params manifest_params; - manifest_params.headers.push_back("Authorization: Bearer " + token); - manifest_params.headers.push_back( - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); - auto manifest_res = common_remote_get_content(manifest_url, manifest_params); - if (manifest_res.first != 200) { - throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); - } - - std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end()); - nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str); - std::string gguf_digest; // Find the GGUF layer - if (manifest.contains("layers")) { - for (const auto & layer : manifest["layers"]) { - if (layer.contains("mediaType")) { - std::string media_type = layer["mediaType"].get(); - if (media_type == "application/vnd.docker.ai.gguf.v3" || - media_type.find("gguf") != std::string::npos) { - gguf_digest = layer["digest"].get(); - break; - } - } - } - } - - if (gguf_digest.empty()) { - throw std::runtime_error("No GGUF layer found in Docker manifest"); - } - - // Validate & normalize digest - gguf_digest = validate_oci_digest(gguf_digest); - LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str()); - - // Prepare local filename - std::string model_filename = repo; - std::replace(model_filename.begin(), model_filename.end(), '/', '_'); - model_filename += "_" + tag + ".gguf"; - std::string local_path = fs_get_cache_file(model_filename); - - const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false)) { - throw std::runtime_error("Failed to download Docker Model"); - } - - LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str()); - return local_path; - } catch (const std::exception & e) { - LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what()); - throw; - } -} - // // utils // diff --git a/common/arg.h b/common/arg.h index 77997c4ef..7ab7e2cea 100644 --- a/common/arg.h +++ b/common/arg.h @@ -59,8 +59,8 @@ struct common_arg { common_arg & set_sparam(); bool in_example(enum llama_example ex); bool is_exclude(enum llama_example ex); - bool get_value_from_env(std::string & output); - bool has_value_from_env(); + bool get_value_from_env(std::string & output) const; + bool has_value_from_env() const; std::string to_string(); }; diff --git a/common/common.h b/common/common.h index 73b14cd01..87ce3c7b0 100644 --- a/common/common.h +++ b/common/common.h @@ -503,6 +503,10 @@ struct common_params { // return false from callback to abort model loading or true to continue llama_progress_callback load_progress_callback = NULL; void * load_progress_callback_user_data = NULL; + + bool has_speculative() const { + return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); + } }; // call once at the start of a program if it uses libcommon diff --git a/common/download.cpp b/common/download.cpp new file mode 100644 index 000000000..a08c85b2c --- /dev/null +++ b/common/download.cpp @@ -0,0 +1,1036 @@ +#include "arg.h" + +#include "common.h" +#include "gguf.h" // for reading GGUF splits +#include "log.h" +#include "download.h" + +#define JSON_ASSERT GGML_ASSERT +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(LLAMA_USE_CURL) +#include +#include +#elif defined(LLAMA_USE_HTTPLIB) +#include "http.h" +#endif + +#ifdef __linux__ +#include +#elif defined(_WIN32) +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif +#elif defined(_AIX) +#include +#else +#include +#endif +#define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// isatty +#if defined(_WIN32) +#include +#else +#include +#endif + +using json = nlohmann::ordered_json; + +// +// downloader +// + +static std::string read_file(const std::string & fname) { + std::ifstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + file.close(); + return content; +} + +static void write_file(const std::string & fname, const std::string & content) { + const std::string fname_tmp = fname + ".tmp"; + std::ofstream file(fname_tmp); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + + try { + file << content; + file.close(); + + // Makes write atomic + if (rename(fname_tmp.c_str(), fname.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str()); + // If rename fails, try to delete the temporary file + if (remove(fname_tmp.c_str()) != 0) { + LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str()); + } + } + } catch (...) { + // If anything fails, try to delete the temporary file + if (remove(fname_tmp.c_str()) != 0) { + LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str()); + } + + throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str())); + } +} + +static void write_etag(const std::string & path, const std::string & etag) { + const std::string etag_path = path + ".etag"; + write_file(etag_path, etag); + LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str()); +} + +static std::string read_etag(const std::string & path) { + std::string none; + const std::string etag_path = path + ".etag"; + + if (std::filesystem::exists(etag_path)) { + std::ifstream etag_in(etag_path); + if (!etag_in) { + LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); + return none; + } + std::string etag; + std::getline(etag_in, etag); + return etag; + } + + // no etag file, but maybe there is an old .json + // remove this code later + const std::string metadata_path = path + ".json"; + + if (std::filesystem::exists(metadata_path)) { + std::ifstream metadata_in(metadata_path); + try { + nlohmann::json metadata_json; + metadata_in >> metadata_json; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), + metadata_json.dump().c_str()); + if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { + std::string etag = metadata_json.at("etag"); + write_etag(path, etag); + if (!std::filesystem::remove(metadata_path)) { + LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); + } + return etag; + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + return none; +} + +#ifdef LLAMA_USE_CURL + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; + +static CURLcode common_curl_perf(CURL * curl) { + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + LOG_ERR("%s: curl_easy_perform() failed\n", __func__); + } + + return res; +} + +// Send a HEAD request to retrieve the etag and last-modified headers +struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + std::string accept_ranges; +}; + +struct FILE_deleter { + void operator()(FILE * f) const { fclose(f); } +}; + +static size_t common_header_callback(char * buffer, size_t, size_t n_items, void * userdata) { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); + static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase); + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; + } else if (std::regex_match(key, match, accept_ranges_regex)) { + headers->accept_ranges = value; + } + } + + return n_items; +} + +static size_t common_write_callback(void * data, size_t size, size_t nmemb, void * fd) { + return std::fwrite(data, size, nmemb, static_cast(fd)); +} + +// helper function to hide password in URL +static std::string llama_download_hide_password_in_url(const std::string & url) { + // Use regex to match and replace the user[:password]@ pattern in URLs + // Pattern: scheme://[user[:password]@]host[...] + static const std::regex url_regex(R"(^(?:[A-Za-z][A-Za-z0-9+.-]://)(?:[^/@]+@)?.$)"); + std::smatch match; + + if (std::regex_match(url, match, url_regex)) { + // match[1] = scheme (e.g., "https://") + // match[2] = user[:password]@ part + // match[3] = rest of URL (host and path) + return match[1].str() + "********@" + match[3].str(); + } + + return url; // No credentials found or malformed URL +} + +static void common_curl_easy_setopt_head(CURL * curl, const std::string & url) { + // Set the URL, allow to follow http redirection + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + +# if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +# endif + + curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, common_header_callback); +} + +static void common_curl_easy_setopt_get(CURL * curl) { + curl_easy_setopt(curl, CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, common_write_callback); + + // display download progress + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); +} + +static bool common_pull_file(CURL * curl, const std::string & path_temporary) { + if (std::filesystem::exists(path_temporary)) { + const std::string partial_size = std::to_string(std::filesystem::file_size(path_temporary)); + LOG_INF("%s: server supports range requests, resuming download from byte %s\n", __func__, partial_size.c_str()); + const std::string range_str = partial_size + "-"; + curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str()); + } + + // Always open file in append mode could be resuming + std::unique_ptr outfile(fopen(path_temporary.c_str(), "ab")); + if (!outfile) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str()); + return false; + } + + common_curl_easy_setopt_get(curl); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile.get()); + + return common_curl_perf(curl) == CURLE_OK; +} + +static bool common_download_head(CURL * curl, + curl_slist_ptr & http_headers, + const std::string & url, + const std::string & bearer_token) { + if (!curl) { + LOG_ERR("%s: error initializing libcurl\n", __func__); + return false; + } + + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + // Check if hf-token or bearer-token was specified + if (!bearer_token.empty()) { + std::string auth_header = "Authorization: Bearer " + bearer_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, http_headers.ptr); + common_curl_easy_setopt_head(curl, url); + return common_curl_perf(curl) == CURLE_OK; +} + +// download one single file from remote URL to local path +static bool common_download_file_single_online(const std::string & url, + const std::string & path, + const std::string & bearer_token) { + static const int max_attempts = 3; + static const int retry_delay_seconds = 2; + for (int i = 0; i < max_attempts; ++i) { + std::string etag; + + // Check if the file already exists locally + const auto file_exists = std::filesystem::exists(path); + if (file_exists) { + etag = read_etag(path); + } else { + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + + // Initialize libcurl + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + common_load_model_from_url_headers headers; + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); + curl_slist_ptr http_headers; + const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); + if (!was_perform_successful) { + head_request_ok = false; + } + + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; + } + + // if head_request_ok is false, we don't have the etag or last-modified headers + // we leave should_download as-is, which is true if the file does not exist + bool should_download_from_scratch = false; + if (head_request_ok) { + // check if ETag or Last-Modified headers are different + // if it is, we need to download the file again + if (!etag.empty() && etag != headers.etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), + headers.etag.c_str()); + should_download = true; + should_download_from_scratch = true; + } + } + + const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none"; + if (should_download) { + if (file_exists && + !accept_ranges_supported) { // Resumable downloads not supported, delete and start again. + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + + const std::string path_temporary = path + ".downloadInProgress"; + if (should_download_from_scratch) { + if (std::filesystem::exists(path_temporary)) { + if (remove(path_temporary.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); + return false; + } + } + + if (std::filesystem::exists(path)) { + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + } + if (head_request_ok) { + write_etag(path, headers.etag); + } + + // start the download + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", + __func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(), + headers.etag.c_str(), headers.last_modified.c_str()); + const bool was_pull_successful = common_pull_file(curl.get(), path_temporary); + if (!was_pull_successful) { + if (i + 1 < max_attempts) { + const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; + LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } else { + LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + } + + continue; + } + + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); + return false; + } + + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + } else { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + } + + break; + } + + return true; +} + +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::vector res_buffer; + + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + auto data_vec = static_cast *>(data); + data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (params.timeout > 0) { + curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout); + } + if (params.max_size > 0) { + curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); + } + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { + http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + std::string error_msg = curl_easy_strerror(res); + throw std::runtime_error("error: cannot make GET request: " + error_msg); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + + return { res_code, std::move(res_buffer) }; +} + +#else +#ifdef LLAMA_USE_HTTPLIB + +static bool is_output_a_tty() { +#if defined(_WIN32) + return _isatty(_fileno(stdout)); +#else + return isatty(1); +#endif +} + +static void print_progress(size_t current, size_t total) { + if (!is_output_a_tty()) { + return; + } + + if (!total) { + return; + } + + size_t width = 50; + size_t pct = (100 * current) / total; + size_t pos = (width * current) / total; + + std::cout << "[" + << std::string(pos, '=') + << (pos < width ? ">" : "") + << std::string(width - pos, ' ') + << "] " << std::setw(3) << pct << "% (" + << current / (1024 * 1024) << " MB / " + << total / (1024 * 1024) << " MB)\r"; + std::cout.flush(); +} + +static bool common_pull_file(httplib::Client & cli, + const std::string & resolve_path, + const std::string & path_tmp, + bool supports_ranges, + size_t existing_size, + size_t & total_size) { + std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app); + if (!ofs.is_open()) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str()); + return false; + } + + httplib::Headers headers; + if (supports_ranges && existing_size > 0) { + headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-"); + } + + std::atomic downloaded{existing_size}; + + auto res = cli.Get(resolve_path, headers, + [&](const httplib::Response &response) { + if (existing_size > 0 && response.status != 206) { + LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status); + return false; + } + if (existing_size == 0 && response.status != 200) { + LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status); + return false; + } + if (total_size == 0 && response.has_header("Content-Length")) { + try { + size_t content_length = std::stoull(response.get_header_value("Content-Length")); + total_size = existing_size + content_length; + } catch (const std::exception &e) { + LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what()); + } + } + return true; + }, + [&](const char *data, size_t len) { + ofs.write(data, len); + if (!ofs) { + LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str()); + return false; + } + downloaded += len; + print_progress(downloaded, total_size); + return true; + }, + nullptr + ); + + std::cout << "\n"; + + if (!res) { + LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1); + return false; + } + + return true; +} + +// download one single file from remote URL to local path +static bool common_download_file_single_online(const std::string & url, + const std::string & path, + const std::string & bearer_token) { + static const int max_attempts = 3; + static const int retry_delay_seconds = 2; + + auto [cli, parts] = common_http_client(url); + + httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; + if (!bearer_token.empty()) { + default_headers.insert({"Authorization", "Bearer " + bearer_token}); + } + cli.set_default_headers(default_headers); + + const bool file_exists = std::filesystem::exists(path); + + std::string last_etag; + if (file_exists) { + last_etag = read_etag(path); + } else { + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + for (int i = 0; i < max_attempts; ++i) { + auto head = cli.Head(parts.path); + bool head_ok = head && head->status >= 200 && head->status < 300; + if (!head_ok) { + LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); + if (file_exists) { + LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); + return true; + } + } + + std::string etag; + if (head_ok && head->has_header("ETag")) { + etag = head->get_header_value("ETag"); + } + + size_t total_size = 0; + if (head_ok && head->has_header("Content-Length")) { + try { + total_size = std::stoull(head->get_header_value("Content-Length")); + } catch (const std::exception& e) { + LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); + } + } + + bool supports_ranges = false; + if (head_ok && head->has_header("Accept-Ranges")) { + supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + } + + bool should_download_from_scratch = false; + if (!last_etag.empty() && !etag.empty() && last_etag != etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, + last_etag.c_str(), etag.c_str()); + should_download_from_scratch = true; + } + + if (file_exists) { + if (!should_download_from_scratch) { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + return true; + } + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + + const std::string path_temporary = path + ".downloadInProgress"; + size_t existing_size = 0; + + if (std::filesystem::exists(path_temporary)) { + if (supports_ranges && !should_download_from_scratch) { + existing_size = std::filesystem::file_size(path_temporary); + } else if (remove(path_temporary.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); + return false; + } + } + + // start the download + LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", + __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); + const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); + if (!was_pull_successful) { + if (i + 1 < max_attempts) { + const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; + LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } else { + LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); + } + continue; + } + + if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + if (!etag.empty()) { + write_etag(path, etag); + } + break; + } + + return true; +} + +std::pair> common_remote_get_content(const std::string & url, + const common_remote_params & params) { + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { + size_t pos = header.find(':'); + if (pos != std::string::npos) { + headers.emplace(header.substr(0, pos), header.substr(pos + 1)); + } else { + headers.emplace(header, ""); + } + } + + if (params.timeout > 0) { + cli.set_read_timeout(params.timeout, 0); + cli.set_write_timeout(params.timeout, 0); + } + + std::vector buf; + auto res = cli.Get(parts.path, headers, + [&](const char *data, size_t len) { + buf.insert(buf.end(), data, data + len); + return params.max_size == 0 || + buf.size() <= static_cast(params.max_size); + }, + nullptr + ); + + if (!res) { + throw std::runtime_error("error: cannot make GET request"); + } + + return { res->status, std::move(buf) }; +} + +#else //no httplib + +bool common_has_curl() { + return false; +} + +static bool common_download_file_single_online(const std::string &, const std::string &, const std::string &) { + LOG_ERR("error: built without CURL, cannot download model from internet\n"); + return false; +} + +std::pair> common_remote_get_content(const std::string & url, const common_remote_params &) { + if (!url.empty()) { + throw std::runtime_error("error: built without CURL, cannot download model from the internet"); + } + + return {}; +} +#endif + + +#endif // LLAMA_USE_CURL + +static bool common_download_file_single(const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline) { + if (!offline) { + return common_download_file_single_online(url, path, bearer_token); + } + + if (!std::filesystem::exists(path)) { + LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + + LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + return true; +} + +// download multiple files from remote URLs to local paths +// the input is a vector of pairs +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { + // Prepare download in parallel + std::vector> futures_download; + for (auto const & item : urls) { + futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline); + }, item)); + } + + // Wait for all downloads to complete + for (auto & f : futures_download) { + if (!f.get()) { + return false; + } + } + + return true; +} + +bool common_download_model( + const common_params_model & model, + const std::string & bearer_token, + bool offline) { + // Basic validation of the model.url + if (model.url.empty()) { + LOG_ERR("%s: invalid model url\n", __func__); + return false; + } + + if (!common_download_file_single(model.url, model.path, bearer_token, offline)) { + return false; + } + + // check for additional GGUFs split to download + int n_split = 0; + { + struct gguf_init_params gguf_params = { + /*.no_alloc = */ true, + /*.ctx = */ NULL, + }; + auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params); + if (!ctx_gguf) { + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str()); + return false; + } + + auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT); + if (key_n_split >= 0) { + n_split = gguf_get_val_u16(ctx_gguf, key_n_split); + } + + gguf_free(ctx_gguf); + } + + if (n_split > 1) { + char split_prefix[PATH_MAX] = {0}; + char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0}; + + // Verify the first split file format + // and extract split URL and PATH prefixes + { + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split); + return false; + } + + if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split); + return false; + } + } + + std::vector> urls; + for (int idx = 1; idx < n_split; idx++) { + char split_path[PATH_MAX] = {0}; + llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split); + + char split_url[LLAMA_MAX_URL_LENGTH] = {0}; + llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split); + + if (std::string(split_path) == model.path) { + continue; // skip the already downloaded file + } + + urls.push_back({split_url, split_path}); + } + + // Download in parallel + common_download_file_multiple(urls, bearer_token, offline); + } + + return true; +} + +common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; + + // headers + std::vector headers; + headers.push_back("Accept: application/json"); + if (!bearer_token.empty()) { + headers.push_back("Authorization: Bearer " + bearer_token); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + // User-Agent header is already set in common_remote_get_content, no need to set it here + + // we use "=" to avoid clashing with other component, while still being allowed on windows + std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; + string_replace_all(cached_response_fname, "/", "_"); + std::string cached_response_path = fs_get_cache_file(cached_response_fname); + + // make the request + common_remote_params params; + params.headers = headers; + long res_code = 0; + std::string res_str; + bool use_cache = false; + if (!offline) { + try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + } + } + if (res_code == 0) { + if (std::filesystem::exists(cached_response_path)) { + LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); + res_str = read_file(cached_response_path); + res_code = 200; + use_cache = true; + } else { + throw std::runtime_error( + offline ? "error: failed to get manifest (offline mode)" + : "error: failed to get manifest (check your internet connection)"); + } + } + std::string ggufFile; + std::string mmprojFile; + + if (res_code == 200 || res_code == 304) { + try { + auto j = json::parse(res_str); + + if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) { + ggufFile = j["ggufFile"]["rfilename"].get(); + } + if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) { + mmprojFile = j["mmprojFile"]["rfilename"].get(); + } + } catch (const std::exception & e) { + throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what()); + } + if (!use_cache) { + // if not using cached response, update the cache file + write_file(cached_response_path, res_str); + } + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (ggufFile.empty()) { + throw std::runtime_error("error: model does not have ggufFile"); + } + + return { hf_repo, ggufFile, mmprojFile }; +} + +// +// Docker registry functions +// + +static std::string common_docker_get_token(const std::string & repo) { + std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; + + common_remote_params params; + auto res = common_remote_get_content(url, params); + + if (res.first != 200) { + throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first)); + } + + std::string response_str(res.second.begin(), res.second.end()); + nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str); + + if (!response.contains("token")) { + throw std::runtime_error("Docker registry token response missing 'token' field"); + } + + return response["token"].get(); +} + +std::string common_docker_resolve_model(const std::string & docker) { + // Parse ai/smollm2:135M-Q4_0 + size_t colon_pos = docker.find(':'); + std::string repo, tag; + if (colon_pos != std::string::npos) { + repo = docker.substr(0, colon_pos); + tag = docker.substr(colon_pos + 1); + } else { + repo = docker; + tag = "latest"; + } + + // ai/ is the default + size_t slash_pos = docker.find('/'); + if (slash_pos == std::string::npos) { + repo.insert(0, "ai/"); + } + + LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str()); + try { + // --- helper: digest validation --- + auto validate_oci_digest = [](const std::string & digest) -> std::string { + // Expected: algo:hex ; start with sha256 (64 hex chars) + // You can extend this map if supporting other algorithms in future. + static const std::regex re("^sha256:([a-fA-F0-9]{64})$"); + std::smatch m; + if (!std::regex_match(digest, m, re)) { + throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest); + } + // normalize hex to lowercase + std::string normalized = digest; + std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){ + return std::tolower(c); + }); + return normalized; + }; + + std::string token = common_docker_get_token(repo); // Get authentication token + + // Get manifest + const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; + std::string manifest_url = url_prefix + "/manifests/" + tag; + common_remote_params manifest_params; + manifest_params.headers.push_back("Authorization: Bearer " + token); + manifest_params.headers.push_back( + "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + auto manifest_res = common_remote_get_content(manifest_url, manifest_params); + if (manifest_res.first != 200) { + throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); + } + + std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end()); + nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str); + std::string gguf_digest; // Find the GGUF layer + if (manifest.contains("layers")) { + for (const auto & layer : manifest["layers"]) { + if (layer.contains("mediaType")) { + std::string media_type = layer["mediaType"].get(); + if (media_type == "application/vnd.docker.ai.gguf.v3" || + media_type.find("gguf") != std::string::npos) { + gguf_digest = layer["digest"].get(); + break; + } + } + } + } + + if (gguf_digest.empty()) { + throw std::runtime_error("No GGUF layer found in Docker manifest"); + } + + // Validate & normalize digest + gguf_digest = validate_oci_digest(gguf_digest); + LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str()); + + // Prepare local filename + std::string model_filename = repo; + std::replace(model_filename.begin(), model_filename.end(), '/', '_'); + model_filename += "_" + tag + ".gguf"; + std::string local_path = fs_get_cache_file(model_filename); + + const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; + if (!common_download_file_single(blob_url, local_path, token, false)) { + throw std::runtime_error("Failed to download Docker Model"); + } + + LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str()); + return local_path; + } catch (const std::exception & e) { + LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what()); + throw; + } +} diff --git a/common/download.h b/common/download.h new file mode 100644 index 000000000..ddf36155e --- /dev/null +++ b/common/download.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +struct common_params_model; + +// +// download functionalities +// + +struct common_hf_file_res { + std::string repo; // repo name with ":tag" removed + std::string ggufFile; + std::string mmprojFile; +}; + +// resolve and download model from Docker registry +// return local path to downloaded model file +std::string common_docker_resolve_model(const std::string & docker); + +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +common_hf_file_res common_get_hf_file( + const std::string & hf_repo_with_tag, + const std::string & bearer_token, + bool offline); + +// returns true if download succeeded +bool common_download_model( + const common_params_model & model, + const std::string & bearer_token, + bool offline); diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ee41a3502..ae0ebb3ca 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -580,16 +580,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); uint8_t *patmp = atmp; int vsums; - int tmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; __asm__ __volatile__( "vsetivli zero, 16, e8, m1\n\t" "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" "vand.vi v0, v1, 0xF\n\t" "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" "vse8.v v0, (%[scale])\n\t" "vsetivli zero, 16, e16, m2\n\t" - "vle16.v v2, (%[bsums])\n\t" "vzext.vf2 v0, v1\n\t" "vwmul.vv v4, v0, v2\n\t" "vsetivli zero, 16, e32, m4\n\t" @@ -608,46 +611,89 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi for (int j = 0; j < QK_K/128; ++j) { __asm__ __volatile__( - "vsetvli zero, %[vl32], e8, m2\n\t" + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" "vsrl.vi v4, v0, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" + "addi %[tmp], %[q8], 32\n\t" "vle8.v v8, (%[q8])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vle8.v v15, (%[scale])\n\t" - "vzext.vf4 v12, v15\n\t" - "vmul.vv v10, v10, v12\n\t" - "vredsum.vs v0, v10, v0\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" "vmv.x.s %[tmp], v0\n\t" - "add %[isum], %[isum], %[tmp]" - : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) : "memory" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" @@ -929,7 +975,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int8_t * restrict q8 = y[i].qs; int8_t * scale = (int8_t *)utmp; - int tmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; __asm__ __volatile__( "vsetivli zero, 12, e8, m1\n\t" "vle8.v v0, (%[s6b])\n\t" @@ -967,19 +1013,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int isum = 0; for (int j = 0; j < QK_K; j += 128) { __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" "vle8.v v8, (%[q3])\n\t" "vsrl.vi v10, v8, 2\n\t" "vsrl.vi v12, v8, 4\n\t" "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" "vand.vi v8, v8, 3\n\t" "vand.vi v10, v10, 3\n\t" "vand.vi v12, v12, 3\n\t" "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" "vand.vx v4, v2, %[m]\n\t" "slli %[m], %[m], 1\n\t" "vmseq.vx v0, v4, zero\n\t" "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" "vand.vx v4, v2, %[m]\n\t" "slli %[m], %[m], 1\n\t" "vmseq.vx v0, v4, zero\n\t" @@ -994,34 +1044,43 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi "vadd.vi v14, v14, -4, v0.t\n\t" "vsetvli zero, %[vl128], e8, m8\n\t" "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" "vsetvli zero, %[vl64], e8, m4\n\t" "vwmul.vv v16, v0, v8\n\t" "vwmul.vv v24, v4, v12\n\t" "vsetivli zero, 16, e16, m2\n\t" "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vle8.v v15, (%[scale])\n\t" - "vsext.vf4 v12, v15\n\t" - "vmul.vv v10, v10, v12\n\t" - "vredsum.vs v0, v10, v0\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" "vmv.x.s %[tmp], v0\n\t" - "add %[isum], %[isum], %[tmp]" - : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) : "memory" diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c5821acbd..50612237c 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -7,6 +7,10 @@ typedef void (*cpy_kernel_t)(const char * cx, char * cdst); +const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks +const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available +const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows + template static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +template +static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { + + const T* src = reinterpret_cast(cx); + T* dst = reinterpret_cast(cdst); + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; + const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset + const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + + __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + +#pragma unroll + for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { + + const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; + if (imat >= nmat) + break; + +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x * sizeof(float)/sizeof(T); + T *tile2 = reinterpret_cast(tile[row]); + tile2[col] = src[imat*n + (y+j)*ne01 + x]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { + if (ty + j < ne01 && tx < ne00) { + const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T); + const T *tile2 = reinterpret_cast(tile[threadIdx.x]); + dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; + } + } + } +} + static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -136,15 +189,36 @@ cudaStream_t stream) { (cx, cdst, ne); } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + if (transposed) { + GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed + int ne00n, ne01n, ne02n; + if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here + ne00n = ne00; + ne01n = ne01; + ne02n = ne02; + } else if (nb00 > nb02) { + ne00n = ne00; + ne01n = ne01*ne02; + ne02n = 1; + } + + dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, + (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); + cpy_flt_transpose<<>> + (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } else { + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } } static void ggml_cpy_f32_q8_0_cuda( @@ -310,6 +384,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); @@ -322,7 +397,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (can_be_transposed) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); @@ -361,7 +440,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (can_be_transposed) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); @@ -375,7 +458,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (can_be_transposed) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { if (contiguous_srcs) { ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 074e03d4a..941d34da0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -29,7 +29,6 @@ bool g_mul_mat_q = true; #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmvf.cuh" #include "ggml-cuda/mmvq.cuh" -#include "ggml-cuda/moe-expert-reduce.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" #include "ggml-cuda/opt-step-sgd.cuh" @@ -2114,7 +2113,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]); const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) || ggml_backend_buft_is_cuda_split(src1->buffer->buft); @@ -2208,16 +2207,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; const int warp_size = ggml_cuda_info().devices[id].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false); - use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false); - use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2296,7 +2295,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * return; } - if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) { + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) { ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); return; } @@ -3165,8 +3164,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - - #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; prev_i = i; @@ -3212,31 +3209,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } - if (node->op == GGML_OP_MUL) { - int current_node = i + 1; - int num_views = 0; - int num_adds = 0; - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - num_views++; - current_node++; - } - - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD && - num_adds < num_views - 1) { - num_adds++; - current_node++; - } - - if (num_adds == num_views - 1 && num_views > 0) { - ggml_tensor * dst_node = cgraph->nodes[current_node - 1]; - if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) { - ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node); - i += num_views + num_adds; - continue; - } - } - } - if (node->op == GGML_OP_ADD) { int n_fuse = 0; ggml_op ops[8]; @@ -3315,6 +3287,13 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && + (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + const ggml_tensor * src0 = up_n->src[0]; const ggml_tensor * src1 = up_n->src[1]; const ggml_tensor * ids = up_n->src[2]; @@ -3424,6 +3403,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + ggml_cuda_mm_fusion_args_host fusion_data{}; fusion_data.x_bias = bias_tensor; diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 2b0a61395..153dd5a97 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -119,15 +119,27 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) { - +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, + const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) { if (ggml_is_quantized(type)) { return false; } - if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { + const size_t ts = ggml_type_size(type); + if (src0_ne[0] % (warp_size * (4/ts)) != 0) { return false; } + + if (src0_nb[0] != ts) { + return false; + } + + // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash: + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (src0_nb[i] % (2*ts) != 0) { + return false; + } + } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index f7e46e2f6..45724e091 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -17,7 +17,7 @@ struct mmf_ids_data { void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id); template __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b2cbfdb77..1094be110 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3495,7 +3495,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int col_diff = col_high - col_low; for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { - ids_dst_shared[j] = ids_dst[col_low + j]; + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; } __syncthreads(); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 4e3178343..6238ce7eb 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -716,10 +716,23 @@ void ggml_cuda_op_mul_mat_vec_f( GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size); } -bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) { if (src0_ne[0] % 2 != 0) { return false; } + + const size_t ts = ggml_type_size(type); + if (src0_nb[0] != ts) { + return false; + } + + // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash: + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (src0_nb[i] % (2*ts) != 0) { + return false; + } + } + switch (type) { case GGML_TYPE_F32: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index a205aa8e4..a09fbdc72 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -9,4 +9,4 @@ void ggml_cuda_op_mul_mat_vec_f( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11); diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cu b/ggml/src/ggml-cuda/moe-expert-reduce.cu deleted file mode 100644 index a97c5d573..000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cu +++ /dev/null @@ -1,168 +0,0 @@ -#include "moe-expert-reduce.cuh" - -// This kernel is a fusion of the expert weight reduce, common in MoE models - -template -__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts, - const float * __restrict__ weights, - float * __restrict__ dst, - const int n_expert_used, - const int n_cols) { - const int row = blockIdx.x; - const int col = blockIdx.y * blockDim.x + threadIdx.x; - if (col >= n_cols) { - return; - } - - experts += row * n_cols * n_expert_used; - weights += row * n_expert_used; - dst += row * n_cols; - - float acc = 0.f; - if constexpr (n_expert_used_template == 0) { - for (int expert = 0; expert < n_expert_used; ++expert) { - ggml_cuda_mad(acc, experts[col], weights[expert]); - experts += n_cols; - } - dst[col] = acc; - } else { -#pragma unroll - for (int i = 0; i < n_expert_used_template; ++i) { - ggml_cuda_mad(acc, experts[col], weights[i]); - experts += n_cols; - } - dst[col] = acc; - } -} - -static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const float * experts, - const float * weights, - float * dst, - const int n_expert_used, - const int n_cols, - const int n_rows) { - const int block_size = 32; - - const int n_blocks_x = n_rows; - const int n_blocks_y = (n_cols + block_size - 1) / block_size; - - dim3 block_dims(block_size); - dim3 grid_dims(n_blocks_x, n_blocks_y); - - cudaStream_t stream = ctx.stream(); - switch (n_expert_used) { - case 1: - moe_expert_reduce_cuda<1> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 2: - moe_expert_reduce_cuda<2> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 4: - moe_expert_reduce_cuda<4> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 6: - moe_expert_reduce_cuda<6> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 8: - moe_expert_reduce_cuda<8> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 16: - moe_expert_reduce_cuda<16> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 32: - moe_expert_reduce_cuda<32> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 64: - moe_expert_reduce_cuda<64> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 128: - moe_expert_reduce_cuda<128> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - default: - moe_expert_reduce_cuda<0> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - } -} - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) { - const ggml_tensor * mul = cgraph->nodes[start_index]; - - if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) { - return false; - } - - int current_node = start_index + 1; - size_t current_offset = 0; - - std::vector view_nodes; - //check if all are views of the expert in increasing order - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - const ggml_tensor * node = cgraph->nodes[current_node]; - if (node->view_src != mul) { - return false; - } - if (node->view_offs < current_offset) { - return false; - } - current_offset = node->view_offs; - current_node++; - view_nodes.push_back(node); - } - - //check if all the adds are in increasing order - const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0]; - int num_adds = 0; - int num_views = view_nodes.size(); - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) { - const ggml_tensor * add_node = cgraph->nodes[current_node]; - - bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false; - bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false; - - if (!is_first_op_ok || !is_second_op_ok) { - return false; - } - prev_add_src = add_node; - - num_adds++; - current_node++; - } - - if (num_views != num_adds + 1) { - return false; - } - - return true; -} - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst) { - const int n_rows = experts->ne[2]; - const int n_expert_used = experts->ne[1]; - const int n_cols = experts->ne[0]; - - GGML_ASSERT(experts->type == GGML_TYPE_F32); - GGML_ASSERT(weights->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(experts)); - GGML_ASSERT(ggml_is_contiguous(weights)); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const float * experts_d = (const float *) experts->data; - const float * weights_d = (const float *) weights->data; - float * dst_d = (float *) dst->data; - - launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows); -} diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cuh b/ggml/src/ggml-cuda/moe-expert-reduce.cuh deleted file mode 100644 index cafc50e10..000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cuh +++ /dev/null @@ -1,11 +0,0 @@ -#include "common.cuh" -#include "ggml.h" - -#include - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst); - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index); diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 945652263..7064b7486 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -367,7 +367,13 @@ struct ggml_backend_hexagon_buffer_context { ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { size += 4 * 1024; // extra page for padding - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + if (rpcmem_alloc2) { + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + } else { + GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str()); + this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + } + if (!this->base) { GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); @@ -1679,12 +1685,13 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } // Get session URI - char htp_uri[256]; - sprintf(htp_uri, "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch); char session_uri[256]; { - struct remote_rpc_get_uri u; + char htp_uri[256]; + snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch); + + struct remote_rpc_get_uri u = {}; u.session_id = this->session_id; u.domain_name = const_cast(CDSP_DOMAIN_NAME); u.domain_name_len = strlen(CDSP_DOMAIN_NAME); @@ -1695,8 +1702,12 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u)); if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: failed to get URI for session %d : error 0x%x\n", dev_id, err); - throw std::runtime_error("ggml-hex: remote_session_control(get-uri) failed (see log for details)"); + // fallback to single session uris + int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN; + + snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri); + + GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri); } } @@ -3668,6 +3679,11 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { } } + if(opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); + } + GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h index 66f9fd373..1a48f5dcb 100644 --- a/ggml/src/ggml-hexagon/htp-utils.h +++ b/ggml/src/ggml-hexagon/htp-utils.h @@ -64,6 +64,7 @@ extern "C" { # pragma weak remote_handle64_control # pragma weak fastrpc_mmap # pragma weak fastrpc_munmap +# pragma weak rpcmem_alloc2 #endif #if !defined(_WINDOWS) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 052efb7ac..b8d35b78a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -35,7 +35,6 @@ struct ggml_metal { // additional, inference-time compiled pipelines ggml_metal_pipelines_t pipelines_ext; - bool use_bfloat; bool use_fusion; bool use_concurrency; bool use_graph_optimize; @@ -121,11 +120,10 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { } } - const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - res->use_bfloat = props_dev->has_bfloat; res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; @@ -147,7 +145,6 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt)); - GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, res->use_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false"); GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4d5829748..cb27dca98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -95,7 +95,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder); typedef struct ggml_metal_library * ggml_metal_library_t; -ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev); +ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev); +ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose); + void ggml_metal_library_free(ggml_metal_library_t lib); ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); @@ -193,6 +195,7 @@ struct ggml_metal_device_props { bool has_simdgroup_mm; bool has_unified_memory; bool has_bfloat; + bool has_tensor; bool use_residency_sets; bool use_shared_buffers; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 5a0e3955c..0ec43f594 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -21,8 +21,9 @@ #define GGML_METAL_HAS_RESIDENCY_SETS 1 #endif -// overload of MTLGPUFamilyMetal3 (not available in some environments) +// overload of MTLGPUFamilyMetalX (not available in some environments) static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; +static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; // virtual address for GPU memory allocations static atomic_uintptr_t g_addr_device = 0x000000400ULL; @@ -261,6 +262,10 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"]; } + if (ggml_metal_device_get_props(dev)->has_tensor) { + [prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"]; + } + #if GGML_METAL_EMBED_LIBRARY [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; #endif @@ -298,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { return res; } +ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) { + if (source == NULL) { + GGML_LOG_ERROR("%s: source is NULL\n", __func__); + return NULL; + } + + id device = ggml_metal_device_get_obj(dev); + id library = nil; + NSError * error = nil; + + const int64_t t_start = ggml_time_us(); + + NSString * src = [[NSString alloc] initWithBytes:source + length:strlen(source) + encoding:NSUTF8StringEncoding]; + if (!src) { + GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__); + return NULL; + } + + @autoreleasepool { + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + library = [device newLibraryWithSource:src options:options error:&error]; + if (error) { + if (verbose) { + GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]); + } else { + GGML_LOG_ERROR("%s: error compiling source\n", __func__); + } + library = nil; + } + + [options release]; + } + + [src release]; + + if (!library) { + if (verbose) { + GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__); + } + + return NULL; + } + + if (verbose) { + GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + } + + ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); + if (!res) { + GGML_LOG_ERROR("%s: calloc failed\n", __func__); + return NULL; + } + + res->obj = library; + res->device = device; + res->pipelines = ggml_metal_pipelines_init(); + + return res; +} + void ggml_metal_library_free(ggml_metal_library_t lib) { if (!lib) { return; @@ -345,9 +416,9 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l if (!mtl_function) { ggml_critical_section_end(); - GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); + GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); if (error) { - GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); } return nil; @@ -355,13 +426,21 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; - ggml_metal_pipelines_add(lib->pipelines, name, res); - [mtl_function release]; GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj, (int) res->obj.maxTotalThreadsPerThreadgroup, (int) res->obj.threadExecutionWidth); + + if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) { + ggml_critical_section_end(); + + GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); + + return nil; + } + + ggml_metal_pipelines_add(lib->pipelines, name, res); } ggml_critical_section_end(); @@ -469,6 +548,126 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6]; + if (getenv("GGML_METAL_BF16_DISABLE") != NULL) { + dev->props.has_bfloat = false; + } + + dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML]; + if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) { + dev->props.has_tensor = false; + } + + // note: disable the tensor API by default for old chips because with the current implementation it is not useful + // - M2 Ultra: ~5% slower + // - M4, M4 Max: no significant difference + // + // TODO: try to update the tensor API kernels to at least match the simdgroup performance + if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL && + ![[dev->mtl_device name] containsString:@"M5"] && + ![[dev->mtl_device name] containsString:@"M6"]) { + GGML_LOG_WARN("%s: tensor API disabled for pre-M5 device\n", __func__); + dev->props.has_tensor = false; + } + + // double-check that the tensor API compiles + if (dev->props.has_tensor) { + const char * src_tensor_f16 = "\n" + "#include \n" + "#include \n" + "#include \n" + " \n" + "using namespace metal; \n" + "using namespace mpp::tensor_ops; \n" + " \n" + "kernel void dummy_kernel( \n" + " tensor> A [[buffer(0)]], \n" + " tensor> B [[buffer(1)]], \n" + " device float * C [[buffer(2)]], \n" + " uint2 tgid [[threadgroup_position_in_grid]]) \n" + "{ \n" + " auto tA = A.slice(0, (int)tgid.y); \n" + " auto tB = B.slice((int)tgid.x, 0); \n" + " \n" + " matmul2d< \n" + " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " execution_simdgroups<4>> mm; \n" + " \n" + " auto cT = mm.get_destination_cooperative_tensor(); \n" + " \n" + " auto sA = tA.slice(0, 0); \n" + " auto sB = tB.slice(0, 0); \n" + " mm.run(sB, sA, cT); \n" + " \n" + " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " \n" + " cT.store(tC); \n" + "}"; + + GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__); + ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false); + if (lib == NULL) { + GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); + dev->props.has_tensor = false; + } else { + ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl) { + GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); + dev->props.has_tensor = false; + } + + ggml_metal_library_free(lib); + } + } + + // try to compile a dummy kernel to determine if the tensor API is supported for bfloat + if (dev->props.has_tensor && dev->props.has_bfloat) { + const char * src_tensor_bf16 = "\n" + "#include \n" + "#include \n" + "#include \n" + " \n" + "using namespace metal; \n" + "using namespace mpp::tensor_ops; \n" + " \n" + "kernel void dummy_kernel( \n" + " tensor> A [[buffer(0)]], \n" + " tensor> B [[buffer(1)]], \n" + " device float * C [[buffer(2)]], \n" + " uint2 tgid [[threadgroup_position_in_grid]]) \n" + "{ \n" + " auto tA = A.slice(0, (int)tgid.y); \n" + " auto tB = B.slice((int)tgid.x, 0); \n" + " \n" + " matmul2d< \n" + " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " execution_simdgroups<4>> mm; \n" + " \n" + " auto cT = mm.get_destination_cooperative_tensor(); \n" + " \n" + " auto sA = tA.slice(0, 0); \n" + " auto sB = tB.slice(0, 0); \n" + " mm.run(sB, sA, cT); \n" + " \n" + " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " \n" + " cT.store(tC); \n" + "}"; + + GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__); + ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false); + if (lib == NULL) { + GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); + dev->props.has_bfloat = false; + } else { + ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); + if (!ppl) { + GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); + dev->props.has_bfloat = false; + } + + ggml_metal_library_free(lib); + } + } dev->props.use_residency_sets = true; #if defined(GGML_METAL_HAS_RESIDENCY_SETS) @@ -476,7 +675,6 @@ ggml_metal_device_t ggml_metal_device_init(void) { #endif dev->props.use_shared_buffers = dev->props.has_unified_memory; - if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { dev->props.use_shared_buffers = false; } @@ -529,6 +727,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false"); GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false"); GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false"); GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false"); GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false"); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 424c400f2..cea535ade 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9,6 +9,12 @@ __embed_ggml-common.h__ #include +#ifdef GGML_METAL_HAS_TENSOR +#include + +#include +#endif + using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) @@ -1742,7 +1748,7 @@ kernel void kernel_op_sum_f32( float sumf = 0; - for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { + for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { sumf += src0[i0]; } @@ -5467,6 +5473,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at #undef FA_TYPES #undef FA_TYPES_BF +#undef FA_TYPES_F32 constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; @@ -6088,6 +6095,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES +#undef FA_TYPES_F32 constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; @@ -8141,17 +8149,6 @@ kernel void kernel_set_rows_f( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - // each block_q contains 16*nl weights template kernel void kernel_mul_mm( @@ -8167,18 +8164,48 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); - const int r0 = tgpig.y; - const int r1 = tgpig.x; + threadgroup float * sc = (threadgroup float *)(shmem); + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + const int im = tgpig.z; + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; // if this block is of 64x32 shape or smaller - const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1; // a thread shouldn't load data outside of the matrix - const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63 + const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31 + const short il0 = (tiitg % NL0); + + short il = il0; + + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il0/nl; + + device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; + + const short iy = 8*(tiitg % NL1); + + device const T1 * y = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 + lr1) + + args.nb10*iy); + +#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -8187,36 +8214,36 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } +#else + auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); + auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - short il = (tiitg % THREAD_PER_ROW); + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + auto cT = mm.get_destination_cooperative_tensor(); +#endif - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const short offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 - + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - - const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); - - device const T1 * y = (device const T1 *)(src1 - + args.nb13*i13 - + args.nb12*i12 - + args.nb11*(r1*BLOCK_SIZE_N + thread_col) - + args.nb10*iy); - - for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { +#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); // no need for dequantization for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + //const short lx = i%8; + //const short ly = (tiitg/NL0)%8; + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; } } else { S0_4x4 temp_a; @@ -8225,91 +8252,203 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); FOR_UNROLL (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + //const short lx = i%8; + //const short ly = (tiitg/NL0)%8; + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + // NOTE: this is massively slower.. WTF? + //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4]; + + *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; } } if (FC_mul_mm_bc_inp) { for (short i = 0; i < 8; ++i) { - sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + const short ib = 4*sx + sy; + + *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; } } else { - *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short dx = sx; + const short dy = sy; + + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } +#else + // load data and store to threadgroup memory + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // no need for dequantization + for (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } + } + + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } + } else { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + //const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); + } +#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; - y += BLOCK_SIZE_K; + + y += NK; threadgroup_barrier(mem_flags::mem_threadgroup); +#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products - threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); - #pragma unroll(4) - for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(4) - for (short i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); - } - - #pragma unroll(2) - for (short i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); } simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(8) - for (short i = 0; i < 8; i++){ + FOR_UNROLL (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } - lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; - lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; + lsma += 8*64; + lsmb += 4*64; } +#else + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mm.run(sB, sA, cT); +#endif } - if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) { + if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory +#ifdef GGML_METAL_HAS_TENSOR device float * C = (device float *) dst + - (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ - (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + r0 + \ + r1 * args.ne0 + im*args.ne1*args.ne0; + + auto tC = tensor, tensor_inline>(C, dextents(args.ne0, NR1)); + cT.store(tC); +#else + device float * C = (device float *) dst + + (r0 + 32*(sgitg & 1)) + \ + (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); + simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } +#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + + threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; + +#ifdef GGML_METAL_HAS_TENSOR + auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); + cT.store(tC); +#else for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } +#endif threadgroup_barrier(mem_flags::mem_threadgroup); if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; + for (int j = tiitg; j < nr1; j += NR1) { + device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; - threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float * C = temp_str + (j*NR0); threadgroup float4 * C4 = (threadgroup float4 *) C; int i = 0; - for (; i < n_rows/4; i++) { + for (; i < nr0/4; i++) { *(D4 + i) = *(C4 + i); } i *= 4; - for (; i < n_rows; i++) { + for (; i < nr0; i++) { *(D + i) = *(C + i); } } @@ -8394,31 +8533,63 @@ kernel void kernel_mul_mm_id( ushort tiitg[[thread_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); - const int r0 = tgpig.y; - const int r1 = tgpig.x; + threadgroup float * sc = (threadgroup float *)(shmem); + + constexpr int NR0 = 64; + constexpr int NR1 = 32; + + constexpr int NK = 32; + constexpr int NL0 = NK/16; + constexpr int NL1 = NK/8; + const int im = tgpig.z; // expert + const int r0 = tgpig.y*NR0; + const int r1 = tgpig.x*NR1; device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); device const int32_t * ids_i32 = (device const int32_t *) (hids); const int32_t neh1 = tpe_u32[im]; - if (r1*BLOCK_SIZE_N >= neh1) { + if (r1 >= neh1) { return; } // if this block is of 64x32 shape or smaller - const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; + const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1; // a thread shouldn't load data outside of the matrix - const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63 + const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31 + const short il0 = (tiitg % NL0); + + short il = il0; + + const int id = ids_i32[im*args.ne21 + r1 + lr1]; + + const short i11 = (id % args.ne20) % args.ne11; + const short i12 = (id / args.ne20); + const short i13 = 0; + + const uint64_t offset0 = im*args.nb02 + i13*args.nb03; + const short offset1 = il0/nl; + + device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; + + const short iy = 8*(tiitg % NL1); + + device const T1 * y = (device const T1 *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*i11 + + args.nb10*iy); + +#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -8427,39 +8598,36 @@ kernel void kernel_mul_mm_id( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } +#else + auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); + auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - short il = (tiitg % THREAD_PER_ROW); + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mm; - const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col]; + auto cT = mm.get_destination_cooperative_tensor(); +#endif - const short i11 = (id % args.ne20) % args.ne11; - const short i12 = (id / args.ne20); - const short i13 = 0; - - const uint64_t offset0 = im*args.nb02 + i13*args.nb03; - const short offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 - + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - - const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); - - device const T1 * y = (device const T1 *)(src1 - + args.nb13*i13 - + args.nb12*i12 - + args.nb11*i11 - + args.nb10*iy); - - for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { +#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); // no need for dequantization for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + //const short lx = i%8; + //const short ly = (tiitg/NL0)%8; + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; } } else { S0_4x4 temp_a; @@ -8468,85 +8636,188 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); FOR_UNROLL (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + //const short lx = i%8; + //const short ly = (tiitg/NL0)%8; + const short lx = (tiitg/NL0)%8; + const short ly = i%8; + + const short ib = 8*sx + sy; + + // NOTE: this is massively slower.. WTF? + //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4]; + + *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; } } if (FC_mul_mm_bc_inp) { for (short i = 0; i < 8; ++i) { - sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + const short ib = 4*sx + sy; + + *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; } } else { - *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short dx = sx; + const short dy = sy; + + const short ly = (tiitg/NL1)%8; + + const short ib = 4*sx + sy; + + *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } +#else + // load data and store to threadgroup memory + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // no need for dequantization + for (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } + } + + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } + } else { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + //const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); + } +#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; - y += BLOCK_SIZE_K; + + y += NK; threadgroup_barrier(mem_flags::mem_threadgroup); +#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products - threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); - #pragma unroll(4) - for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { - #pragma unroll(4) - for (short i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); } simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (short i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + FOR_UNROLL (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); } - #pragma unroll(8) - for (short i = 0; i < 8; i++){ + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } - lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; - lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; + lsma += 8*64; + lsmb += 4*64; } +#else + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mm.run(sB, sA, cT); +#endif } + // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; +#ifdef GGML_METAL_HAS_TENSOR + auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); + cT.store(tC); +#else + threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; - #pragma unroll(8) for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } +#endif threadgroup_barrier(mem_flags::mem_threadgroup); - for (short j = sgitg; j < n_cols; j += 4) { - const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j]; + for (short j = sgitg; j < nr1; j += 4) { + const int id = ids_i32[im*args.ne21 + r1 + j]; const short ide = id % args.ne20; const short idt = id / args.ne20; - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0; + device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; - threadgroup float * C = (threadgroup float *) shmem + (j*BLOCK_SIZE_M); + threadgroup float * C = (threadgroup float *) shmem + j*NR0; threadgroup float4 * C4 = (threadgroup float4 *) C; int i = tiisg; - for (; i < n_rows/4; i += 32) { + for (; i < nr0/4; i += 32) { *(D4 + i) = *(C4 + i); } - i = (4*(n_rows/4)) + tiisg; - for (; i < n_rows; i += 32) { + i = (4*(nr0/4)) + tiisg; + for (; i < nr0; i += 32) { *(D + i) = *(C + i); } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index dc03b5828..13f8c3070 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -146,9 +146,9 @@ struct vk_pipeline_struct { // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; // set to true to request the pipeline is compiled - bool needed {}; + std::atomic needed {}; // set to true when the shader has been compiled - bool compiled {}; + std::atomic compiled {}; // number of registers used, extracted from pipeline executable properties uint32_t register_count {}; }; @@ -482,6 +482,14 @@ static constexpr std::initializer_list> rope_view_set_rows_ed { 2, 0, 1 }, // set_rows->src[0] == view }; +static constexpr std::initializer_list> rms_norm_mul_rope_view_set_rows_edges { + { 1, 0, 0 }, // mul->src[0] == rms + { 2, 0, 1 }, // rope->src[0] == mul + { 3, 0, 2 }, // view->src[0] == rope + { 4, 0, 3 }, // set_rows->src[0] == view +}; + + struct vk_device_struct { std::recursive_mutex mutex; @@ -633,6 +641,8 @@ struct vk_device_struct { vk_pipeline pipeline_rms_norm_mul_f32; vk_pipeline pipeline_rms_norm_partials_f32; vk_pipeline pipeline_rms_norm_mul_partials_f32; + vk_pipeline pipeline_rms_norm_mul_rope_f32_f32; + vk_pipeline pipeline_rms_norm_mul_rope_f32_f16; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; @@ -1076,6 +1086,7 @@ struct vk_op_diag_mask_push_constants { }; struct vk_op_rope_push_constants { + uint32_t rope_mode; uint32_t ncols; uint32_t n_dims; float freq_scale; @@ -1095,6 +1106,12 @@ struct vk_op_rope_push_constants { uint32_t set_rows_stride; }; +// For fused rms_norm+mul+rope(+view+set_rows) +struct vk_op_rms_norm_mul_rope_push_constants { + vk_op_binary_push_constants bin; + vk_op_rope_push_constants rope; +}; + struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; @@ -1858,10 +1875,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - { - std::lock_guard guard(device->mutex); - device->all_pipelines.push_back(pipeline); - } + device->all_pipelines.push_back(pipeline); { std::lock_guard guard(compile_count_mutex); @@ -2552,6 +2566,7 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + std::lock_guard guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -2745,6 +2760,8 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!pipeline->needed || pipeline->compiled) { return; } + // TODO: We're no longer benefitting from the async compiles (shaders are + // compiled individually, as needed) and this complexity can be removed. { // wait until fewer than N compiles are in progress uint32_t N = std::max(1u, std::thread::hardware_concurrency()); @@ -3573,6 +3590,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + if (device->float_controls_rte_fp16 && + sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + } + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -5417,7 +5440,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { device->pinned_memory.erase(device->pinned_memory.begin() + index); } -static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { +static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { std::lock_guard guard(device->mutex); buf = nullptr; buf_offset = 0; @@ -5432,6 +5455,32 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf } } +static vk_subbuffer ggml_vk_tensor_subbuffer( + const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) { + + vk_buffer buffer = nullptr; + size_t offset = 0; + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensor->data, buffer, offset); + } + if (!buffer) { + auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + buffer = buf_ctx->dev_buffer; + offset = vk_tensor_offset(tensor) + tensor->view_offs; + } + GGML_ASSERT(buffer != nullptr); + + size_t size = ggml_nbytes(tensor); + + size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + // The shader must support misaligned offsets when indexing into the buffer + GGML_ASSERT(allow_misalign || misalign_bytes == 0); + offset &= ~misalign_bytes; + size += misalign_bytes; + + return vk_subbuffer{buffer, offset, size}; +} + static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; s.buffer = ggml_vk_create_cmd_buffer(device, p); @@ -7918,12 +7967,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; - auto it = pipelines.find(fa_pipeline_state); - if (it != pipelines.end()) { - pipeline = it->second; - } else { - pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + { + std::lock_guard guard(ctx->device->mutex); + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } } assert(pipeline); @@ -7983,72 +8035,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr; - size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0; - - bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false; - - if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); - ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); - ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); - Q_uma = d_Q != nullptr; - K_uma = d_K != nullptr; - V_uma = d_V != nullptr; - D_uma = d_D != nullptr; - if (mask) { - ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); - M_uma = d_M != nullptr; - } - if (sinks) { - ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset); - S_uma = d_S != nullptr; - } - } - - - ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; - ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; - ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; - - if (!Q_uma) { - d_Q = q_buf_ctx->dev_buffer; - q_buf_offset = vk_tensor_offset(q) + q->view_offs; - } - if (!K_uma) { - d_K = k_buf_ctx->dev_buffer; - k_buf_offset = vk_tensor_offset(k) + k->view_offs; - } - if (!V_uma) { - d_V = v_buf_ctx->dev_buffer; - v_buf_offset = vk_tensor_offset(v) + v->view_offs; - } - if (!D_uma) { - d_D = d_buf_ctx->dev_buffer; - d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; - } - - if (!M_uma) { - d_M = d_Q; - m_buf_offset = q_buf_offset; - if (mask) { - ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; - d_M = m_buf_ctx->dev_buffer; - m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; - } - } - - if (!S_uma) { - d_S = d_Q; - s_buf_offset = q_buf_offset; - if (sinks) { - ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context; - d_S = s_buf_ctx->dev_buffer; - s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs; - } - } + vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q); + vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k); + vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf; + vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; @@ -8070,15 +8062,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_sync_buffers(ctx, subctx); } + vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), - ggml_vk_subbuffer(ctx, d_K, k_buf_offset), - ggml_vk_subbuffer(ctx, d_V, v_buf_offset), - ggml_vk_subbuffer(ctx, d_M, m_buf_offset), - ggml_vk_subbuffer(ctx, d_S, s_buf_offset), - ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), - }, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to @@ -8088,23 +8074,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, - { - ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), - ggml_vk_subbuffer(ctx, d_S, s_buf_offset), - ggml_vk_subbuffer(ctx, d_D, d_buf_offset), - }, + {split_k_buf, sinks_buf, dst_buf}, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), - ggml_vk_subbuffer(ctx, d_K, k_buf_offset), - ggml_vk_subbuffer(ctx, d_V, v_buf_offset), - ggml_vk_subbuffer(ctx, d_M, m_buf_offset), - ggml_vk_subbuffer(ctx, d_S, s_buf_offset), - ggml_vk_subbuffer(ctx, d_D, d_buf_offset), - }, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); } } @@ -8787,35 +8762,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint64_t ne01 = src0->ne[1]; const uint64_t ne02 = src0->ne[2]; const uint64_t ne03 = src0->ne[3]; - const uint64_t ne0 = ne00 * ne01; const bool use_src1 = src1 != nullptr; const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; - const uint64_t ne1 = ne10 * ne11; - // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; const bool use_src2 = src2 != nullptr; - const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; - const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; - const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; - const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; - const uint64_t ne2 = ne20 * ne21; - const bool use_src3 = src3 != nullptr; - const uint64_t ne30 = use_src3 ? src3->ne[0] : 0; - const uint64_t ne31 = use_src3 ? src3->ne[1] : 0; - const uint64_t ne32 = use_src3 ? src3->ne[2] : 0; - const uint64_t ne33 = use_src3 ? src3->ne[3] : 0; - const uint64_t ne3 = ne30 * ne31; - - const uint64_t ned0 = dst->ne[0]; - const uint64_t ned1 = dst->ne[1]; - const uint64_t ned2 = dst->ne[2]; - const uint64_t ned3 = dst->ne[3]; - const uint64_t ned = ned0 * ned1; init_pushconst_fastdiv(pc); @@ -8834,74 +8789,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; - ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; - ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; - ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr; + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous); + vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{}; + vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{}; + vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{}; + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous); - vk_buffer d_X = nullptr; - size_t x_buf_offset = 0; - vk_buffer d_Y = nullptr; - size_t y_buf_offset = 0; - vk_buffer d_Z = nullptr; - size_t z_buf_offset = 0; - vk_buffer d_W = nullptr; - size_t w_buf_offset = 0; - - bool src0_uma = false; - bool src1_uma = false; - bool src2_uma = false; - bool src3_uma = false; - - if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); - src0_uma = d_X != nullptr; - if (use_src1) { - ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); - src1_uma = d_Y != nullptr; - } - if (use_src2) { - ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); - src2_uma = d_Z != nullptr; - } - if (use_src3) { - ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset); - src3_uma = d_W != nullptr; - } - } - - vk_buffer d_D = dst_buf_ctx->dev_buffer; - - GGML_ASSERT(d_D != nullptr); - uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; - if(!src0_uma) { - d_X = src0_buf_ctx->dev_buffer; - x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; - GGML_ASSERT(d_X != nullptr); - } - if (use_src1 && !src1_uma) { - d_Y = src1_buf_ctx->dev_buffer; - y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; - GGML_ASSERT(d_Y != nullptr); - } - if (use_src2 && !src2_uma) { - d_Z = src2_buf_ctx->dev_buffer; - z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; - GGML_ASSERT(d_Z != nullptr); - } - if (use_src3 && !src3_uma) { - d_W = src3_buf_ctx->dev_buffer; - w_buf_offset = vk_tensor_offset(src3) + src3->view_offs; - GGML_ASSERT(d_W != nullptr); - } - // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + // Compute misalignment offset for descriptors and store it in in push constants. init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); - x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); std::array elements; @@ -8985,9 +8880,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t KH = ne01; const uint32_t KW = ne00; - const uint32_t OD = ned3 / N; - const uint32_t OH = ned2; - const uint32_t OW = ned1; + const uint32_t OD = dst->ne[3] / N; + const uint32_t OH = dst->ne[2]; + const uint32_t OW = dst->ne[1]; const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; const uint32_t N_OD_OH = N*OD*OH; @@ -9102,112 +8997,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; } - uint64_t x_sz, y_sz, z_sz, w_sz, d_sz; - - if (op_supports_incontiguous) { - x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); - y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; - w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0; - d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); - - if (x_buf_offset + x_sz >= d_X->size) { - x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset); - } - if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { - y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset); - } - if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { - z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); - } - if (use_src3 && w_buf_offset + w_sz >= d_W->size) { - w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset); - } - if (d_buf_offset + d_sz >= d_D->size) { - d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); - } - } else { - x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; - y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; - z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; - w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0; - d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; - } - if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { - vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; - size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; + vk_subbuffer a_buf = src0_buf; + if (ctx->do_add_rms_partials) { + a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset); + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { vk_subbuffer{ d_X, x_buf_offset, x_sz }, - vk_subbuffer{ d_Y, y_buf_offset, y_sz }, - vk_subbuffer{ d_D, d_buf_offset, d_sz }, - ggml_vk_subbuffer(ctx, d_A, a_buf_offset), - }, pc, elements); + { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements); } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer - vk_subbuffer subbuf_y; - if (use_src1) { - subbuf_y = { d_Y, y_buf_offset, y_sz }; - } else { - subbuf_y = { d_X, 0, x_sz }; - } - - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements); } else if (op == GGML_OP_SOFT_MAX) { // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer - vk_subbuffer subbuf_y; - if (use_src1) { - subbuf_y = { d_Y, y_buf_offset, y_sz }; - } else { - subbuf_y = { d_X, 0, x_sz }; - } - - vk_subbuffer subbuf_z; - if (use_src2) { - subbuf_z = { d_Z, z_buf_offset, z_sz }; - } else { - subbuf_z = { d_X, 0, x_sz }; - } - - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf; + vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { - // Empty src2 is possible in rope, but the shader needs a buffer - vk_subbuffer subbuf_z, subbuf_w; - if (use_src2) { - subbuf_z = { d_Z, z_buf_offset, z_sz }; - } else { - subbuf_z = { d_X, 0, x_sz }; - } - if (use_src3) { - subbuf_w = { d_W, w_buf_offset, w_sz }; - } else { - subbuf_w = { d_X, 0, x_sz }; - } - - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements); + // Empty src2 and src3 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf; + vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements); } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { // buffer device address path doesn't use dst buffer - d_sz = 1; + dst_buf.size = 1; } // im2col uses only src1 and dst buffers - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { // count_equal assumes that destination buffer is initialized with zeroes - ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size); ggml_vk_sync_buffers(ctx, subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements); } else if (op == GGML_OP_OPT_STEP_SGD) { // OPT_STEP_SGD works on src0, it does not need dst - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements); } else if (use_src3) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements); } else if (use_src2) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements); } else if (use_src1) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements); } else { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements); } } @@ -9443,39 +9276,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[7] = {}; for (int i = 0; i < num_srcs; i++) { - src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; - } - - vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; - size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; - bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; - - if (ctx->device->uma) { - for (int i = 0; i < num_srcs; i++) { - ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); - srcs_uma[i] = d_srcs[i] != nullptr; - } - - ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); - dst_uma = d_D != nullptr; - } - - uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; - for (int i = 0; i < num_srcs; i++) { - src_sizes[i] = ggml_nbytes(dst->src[i]); - if (!srcs_uma[i]) { - d_srcs[i] = src_buf_ctxs[i]->dev_buffer; - src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; - } - } - - const uint64_t dst_size = ggml_nbytes(dst); - if (!dst_uma) { - d_D = dst_buf_ctx->dev_buffer; - dst_offset = vk_tensor_offset(dst) + dst->view_offs; + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); } std::array elements = { @@ -9485,26 +9289,13 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx }; if (version == 6) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, - vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, - vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, - vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, - vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, - vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, elements); } else if (version == 7) { - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, - vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, - vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, - vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, - vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, - vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, - vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, + pc, elements); } else { // shouldn't happen GGML_ASSERT(false); @@ -9584,40 +9375,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, n_head, head_dim, n_group, n_tok }; - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC]; - for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { - src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; - } - - vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr }; - size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 }; - bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false }; - - if (ctx->device->uma) { - for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { - ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); - srcs_uma[i] = d_srcs[i] != nullptr; - } - ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); - dst_uma = d_D != nullptr; - } - - if (!dst_uma) { - d_D = dst_buf_ctx->dev_buffer; - dst_offset = vk_tensor_offset(dst) + dst->view_offs; - } - for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { - if (!srcs_uma[i]) { - d_srcs[i] = src_buf_ctxs[i]->dev_buffer; - src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; - } - } - - size_t dst_size = ggml_nbytes(dst); - size_t src_sizes[GGML_MAX_SRC]; - for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) { - src_sizes[i] = ggml_nbytes(dst->src[i]); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[7] = {}; + for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); } std::array elements; @@ -9627,16 +9388,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t num_workgroups_y = n_seq; elements = { num_workgroups_x, num_workgroups_y, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, - vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, - vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, - vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, - vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, - vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, - vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf}, + pc, elements); } static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { @@ -9683,66 +9437,17 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; - ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; - ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; - ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; - ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; - - vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; - size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; - bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; - - if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); - ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); - ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); - ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); - ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); - - X_uma = d_X != nullptr; - G_uma = d_G != nullptr; - GM_uma = d_GM != nullptr; - GV_uma = d_GV != nullptr; - P_uma = d_P != nullptr; - } - - if (!X_uma) { - d_X = x_buf_ctx->dev_buffer; - x_offset = vk_tensor_offset(x) + x->view_offs; - } - if (!G_uma) { - d_G = g_buf_ctx->dev_buffer; - g_offset = vk_tensor_offset(g) + g->view_offs; - } - if (!GM_uma) { - d_GM = gm_buf_ctx->dev_buffer; - gm_offset = vk_tensor_offset(gm) + gm->view_offs; - } - if (!GV_uma) { - d_GV = gv_buf_ctx->dev_buffer; - gv_offset = vk_tensor_offset(gv) + gv->view_offs; - } - if (!P_uma) { - d_P = p_buf_ctx->dev_buffer; - p_offset = vk_tensor_offset(p) + p->view_offs; - } - - const uint64_t x_size = ggml_nbytes(x); - const uint64_t g_size = ggml_nbytes(g); - const uint64_t gm_size = ggml_nbytes(gm); - const uint64_t gv_size = ggml_nbytes(gv); - const uint64_t p_size = ggml_nbytes(p); + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g); + vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm); + vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv); + vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p); std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_X, x_offset, x_size }, - vk_subbuffer{ d_G, g_offset, g_size }, - vk_subbuffer{ d_GM, gm_offset, gm_size }, - vk_subbuffer{ d_GV, gv_offset, gv_size }, - vk_subbuffer{ d_P, p_offset, p_size }, - }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {x_buf, g_buf, gm_buf, gv_buf, p_buf}, + pc, elements); } static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { @@ -9938,21 +9643,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g return num_bytes; } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) { +static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) { + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // const int n_ctx = ((const int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + const float freq_base = ((const float *) dst->op_params)[5]; + const float freq_scale = ((const float *) dst->op_params)[6]; + const float ext_factor = ((const float *) dst->op_params)[7]; + const float attn_factor = ((const float *) dst->op_params)[8]; + const float beta_fast = ((const float *) dst->op_params)[9]; + const float beta_slow = ((const float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + + vk_op_rope_push_constants rope { + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + has_ff, (uint32_t)src0->ne[2], nb01, nb02, + { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + }; + + return rope; +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) { + ggml_tensor * dst; + const ggml_tensor * src0; + const ggml_tensor * src1; + + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0]; + dst = mul; + src0 = cgraph->nodes[node_idx]->src[0]; + src1 = other_src; + } else { + dst = cgraph->nodes[node_idx]; + src0 = src1 = dst->src[0]; + } + const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + vk_op_binary_push_constants bin { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], 0.0f, (int32_t)param3, - }); + }; + + // more than one fused op means rms_norm+mul+rope + if (ctx->num_additional_fused_ops > 1) { + static constexpr uint32_t max_tensors = 7; + const ggml_tensor *tensors[max_tensors] {}; + + ggml_tensor *rms = cgraph->nodes[node_idx + 0]; + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *rope = cgraph->nodes[node_idx + 2]; + + ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; + + bool do_set_rows = ctx->num_additional_fused_ops == 4; + + tensors[0] = rms->src[0]; + tensors[1] = other_src; + tensors[2] = mul; + tensors[3] = rope->src[1]; // pos + tensors[4] = rope->src[2]; // ff + tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst + tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr; + const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0; + + vk_op_rms_norm_mul_rope_push_constants pc; + pc.bin = bin; + pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride); + + vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + ggml_backend_vk_buffer_context * buf_ctx[max_tensors]; + vk_buffer buf[max_tensors]; + size_t offset[max_tensors]; + bool uma[max_tensors]; + + for (uint32_t i = 0; i < max_tensors; ++i) { + if (!tensors[i]) { + // If any remaining descriptors are unused, just point them at src[0] + buf[i] = buf[0]; + offset[i] = 0; + continue; + } + buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; + buf[i] = nullptr; + offset[i] = 0; + uma[i] = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); + uma[i] = buf[i] != nullptr; + } + if (!uma[i]) { + buf[i] = buf_ctx[i]->dev_buffer; + offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs; + } + GGML_ASSERT(buf[i] != nullptr); + } + + std::array elements; + elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; + + static_assert(max_tensors == 7); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + }, pc, elements); + } else { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin)); + } if (ctx->do_add_rms_partials_offset_calculation) { ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); @@ -10074,45 +9907,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context; - ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context; - ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; - - vk_buffer d_logits = nullptr; - size_t logits_buf_offset = 0; - vk_buffer d_weights = nullptr; - size_t weights_buf_offset = 0; - vk_buffer d_ids = nullptr; - size_t ids_buf_offset = 0; - - bool logits_uma = false; - bool weights_uma = false; - bool ids_uma = false; - - if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset); - ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset); - ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); - logits_uma = d_logits != nullptr; - weights_uma = d_weights != nullptr; - ids_uma = d_ids != nullptr; - } - - if (!logits_uma) { - d_logits = logits_buf_ctx->dev_buffer; - logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs; - GGML_ASSERT(d_logits != nullptr); - } - if (!weights_uma) { - d_weights = weights_buf_ctx->dev_buffer; - weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs; - GGML_ASSERT(d_weights != nullptr); - } - if (!ids_uma) { - d_ids = ids_buf_ctx->dev_buffer; - ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; - GGML_ASSERT(d_ids != nullptr); - } + vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); + vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); + vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; @@ -10128,12 +9925,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t rows_per_block = 4; std::array elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset), - ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset), - ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset), - }, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { @@ -10147,9 +9939,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons // const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const float freq_base = ((float *) dst->op_params)[5]; - const float freq_scale = ((float *) dst->op_params)[6]; - const float ext_factor = ((float *) dst->op_params)[7]; - const float attn_factor = ((float *) dst->op_params)[8]; const float beta_fast = ((float *) dst->op_params)[9]; const float beta_slow = ((float *) dst->op_params)[10]; int sections[4] {}; @@ -10157,16 +9946,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); } - const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; - float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); - uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); - uint32_t set_rows_stride = 0; // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride // and overrides the dst and sets src3=row_indices @@ -10176,12 +9958,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons dst = cgraph->nodes[node_idx + 2]; } - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, { - (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, - { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, - }); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, + ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride)); } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11696,6 +11474,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (n->op == GGML_OP_GLU) { std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; } + if (n->op == GGML_OP_ROPE) { + const int mode = ((const int32_t *) n->op_params)[2]; + std::cerr << " rope mode: " << mode; + } std::cerr << std::endl; } #endif @@ -11803,14 +11585,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_RMS_NORM: - if (ctx->num_additional_fused_ops > 0) { - // fused rms_norm + mul - ggml_tensor *mul = cgraph->nodes[node_idx + 1]; - ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; - ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params); - } else { - ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params); - } + ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params); break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node); @@ -12796,6 +12571,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Check whether the tensors overlap in memory but are not equal. +// Fusions can potenitally overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them +// to overlap if they are exactly equal. +// XXX TODO this check is probably missing from several fusion optimizations. +static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; + vk_buffer b_buf = b_buf_ctx->dev_buffer; + if (a_buf == b_buf) { + auto a_base = vk_tensor_offset(a) + a->view_offs; + auto a_size = ggml_nbytes(a); + auto b_base = vk_tensor_offset(b) + b->view_offs; + auto b_size = ggml_nbytes(b); + + if (a_base == b_base && a_size == b_size) { + return false; + } + + if ((b_base <= a_base && a_base < b_base + b_size) || + (a_base <= b_base && b_base < a_base + a_size)) { + return true; + } + } + return false; +} + +static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx) { + GGML_UNUSED(ctx); + const ggml_tensor *rms = cgraph->nodes[node_idx + 0]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + const ggml_tensor *rope = cgraph->nodes[node_idx + 2]; + + const int mode = ((const int32_t *) rope->op_params)[2]; + + // noncontig tensors aren't tested, and don't seem common in practice + if (!ggml_is_contiguous(rms) || + !ggml_is_contiguous(mul) || + !ggml_is_contiguous(rope)) { + return false; + } + + // only norm/neox are handled in the shader + if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) { + return false; + } + + // shared memory size for passing data from mul->rope + if (mul->ne[0] > 1024) { + return false; + } + + // must not overwrite srcs in a way that's not elementwise + ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; + if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || + ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { + return false; + } + + return true; +} + static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { const ggml_tensor *first_node = cgraph->nodes[node_idx]; @@ -12941,12 +12780,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; - } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && + ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && + ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && + ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { + ctx->num_additional_fused_ops = 4; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& + ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 2; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -13179,14 +13026,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } if (ok) { current_set.push_back(j); + + int rope_idx = j; + + // When we've found RMS_NORM + MUL, try to find a ROPE that uses it + if (j > 0 && + graph->nodes[j]->op == GGML_OP_MUL && + graph->nodes[j-1]->op == GGML_OP_RMS_NORM) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_ROPE && + graph->nodes[k]->src[0] == graph->nodes[j] && + // Check that other srcs are already valid + graph->nodes[k]->src[1]->op == GGML_OP_NONE && + (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) { + rope_idx = k; + current_set.push_back(rope_idx); + used[rope_idx] = true; + break; + } + } + } // Look for ROPE + VIEW + SET_ROWS and make them consecutive - if (graph->nodes[j]->op == GGML_OP_ROPE) { + if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) { int view_idx = -1; int set_rows_idx = -1; - for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) { + for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) { if (view_idx == -1 && graph->nodes[k]->op == GGML_OP_VIEW && - graph->nodes[k]->src[0] == graph->nodes[j]) { + graph->nodes[k]->src[0] == graph->nodes[rope_idx]) { view_idx = k; continue; } @@ -14134,20 +14001,11 @@ size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { - ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops]; if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } - bool fused_rms_norm_mul = false; - int rms_norm_idx = -1; - if (ctx->num_additional_fused_ops == 1 && - tensor->op == GGML_OP_RMS_NORM && - cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { - fused_rms_norm_mul = true; - tensor = cgraph->nodes[tensor_idx + 1]; - } - check_counter++; if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; @@ -14155,9 +14013,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); - ggml_tensor * src0 = tensor->src[0]; - ggml_tensor * src1 = tensor->src[1]; - struct ggml_init_params iparams = { /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, /*.mem_buffer =*/ NULL, @@ -14167,328 +14022,339 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * struct ggml_context * ggml_ctx = ggml_init(iparams); std::array src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; - std::array src_size = {}; - std::array src_buffer = {}; const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"}; + std::map cloned_tensors; + std::vector cloned_mallocs; + struct ggml_tensor * tensor_clone = nullptr; - for (int i = 0; i < GGML_MAX_SRC; i++) { - ggml_tensor * srci = tensor->src[i]; - if (fused_rms_norm_mul) { - rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; - ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; - switch (i) { - case 0: srci = rms_norm->src[0]; break; - case 1: srci = tensor->src[1 - rms_norm_idx]; break; - default: continue; + for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) { + tensor = cgraph->nodes[tensor_idx + f]; + for (int i = 0; i < GGML_MAX_SRC; i++) { + ggml_tensor * srci = tensor->src[i]; + if (srci == nullptr) { + continue; } - } - if (srci == nullptr) { - continue; - } - ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); - size_t srci_size = ggml_nbytes(srci); + // If a src tensor has been cloned, use that one + auto it = cloned_tensors.find(srci); + if (it != cloned_tensors.end()) { + src_clone[i] = it->second; + continue; + } + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); - src_clone[i] = srci_clone; - src_size[i] = ggml_nbytes(srci); - src_buffer[i] = malloc(srci_size); + src_clone[i] = srci_clone; + void *src_buffer = malloc(srci_size); + cloned_mallocs.push_back(src_buffer); - srci_clone->data = src_buffer[i]; - if (ggml_backend_buffer_is_host(srci->buffer)) { - memcpy(srci_clone->data, srci->data, srci_size); - memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); - } else if (ggml_backend_buffer_is_vk(srci->buffer)) { - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; - vk_buffer& buffer_gpu = buf_ctx->dev_buffer; - uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; - if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { - for (int i3 = 0; i3 < srci->ne[3]; i3++) { - for (int i2 = 0; i2 < srci->ne[2]; i2++) { - const int idx = i3*srci->ne[2] + i2; - ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); - } - } - - srci_clone->nb[0] = srci->nb[0]; - srci_clone->nb[1] = srci->nb[1]; - for (int i = 2; i < GGML_MAX_DIMS; i++) { - srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; - } - } else { - if (offset + srci_size >= buffer_gpu->size) { - srci_size = buffer_gpu->size - offset; - } - ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + srci_clone->data = src_buffer; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); + } + } + + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; + } + } else { + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(srci, srci_name[i]); } - } else { - GGML_ABORT("fatal error"); } - if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { - ggml_vk_print_tensor(srci, srci_name[i]); - } - } - - if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); - if (src_clone[4]) { - ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); - } - } else if (tensor->op == GGML_OP_MUL_MAT) { - tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_MUL_MAT_ID) { - tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); - } else if (tensor->op == GGML_OP_SUB) { - tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_MUL) { - if (fused_rms_norm_mul) { - tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); - tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); - } else { - tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); - } - } else if (tensor->op == GGML_OP_DIV) { - tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_CONCAT) { - tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); - } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); - } else if (tensor->op == GGML_OP_SCALE) { - const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); - } else if (tensor->op == GGML_OP_SQR) { - tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_SQRT) { - tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_SIN) { - tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_COS) { - tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_CLAMP) { - const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); - } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], - tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); - } else if (tensor->op == GGML_OP_REPEAT) { - tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); - } else if (tensor->op == GGML_OP_REPEAT_BACK) { - tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); - } else if (tensor->op == GGML_OP_ADD) { - tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_ACC) { - tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); - } else if (tensor->op == GGML_OP_NORM) { - tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); - } else if (tensor->op == GGML_OP_GROUP_NORM) { - const float * float_params = (const float *)tensor->op_params; - tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); - } else if (tensor->op == GGML_OP_RMS_NORM) { - tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); - } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { - const float eps = ((float *) tensor->op_params)[0]; - tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); - } else if (tensor->op == GGML_OP_SILU_BACK) { - tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_L2_NORM) { - const float eps = ((float *) tensor->op_params)[0]; - tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); - } else if (tensor->op == GGML_OP_SOFT_MAX) { - if (src1 != nullptr) { + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float * params = (const float *)tensor->op_params; - tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); - } else { - tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); - } - } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { - tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); - } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); - } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; - const float freq_base = ((float *) tensor->op_params)[5]; - const float freq_scale = ((float *) tensor->op_params)[6]; - const float ext_factor = ((float *) tensor->op_params)[7]; - const float attn_factor = ((float *) tensor->op_params)[8]; - const float beta_fast = ((float *) tensor->op_params)[9]; - const float beta_slow = ((float *) tensor->op_params)[10]; - if (mode & GGML_ROPE_TYPE_MROPE) { - int32_t *sections = ((int32_t *) tensor->op_params) + 11; - if (tensor->op == GGML_OP_ROPE) { - tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - } else { - tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + if (src_clone[4]) { + ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); } - } else { - if (tensor->op == GGML_OP_ROPE) { - tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL) { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + } else if (tensor->op == GGML_OP_SCALE) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SQRT) { + tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CLAMP) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (tensor->src[1] != nullptr) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); } else { - tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } else { + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_ERF: + tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSIGMOID: + tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSWISH: + tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } + ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); + ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (tensor->src[1] == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); + tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_2D_DW) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { + const int32_t s = tensor->op_params[0]; + tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = tensor->src[0]->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); + } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { + src_clone[0]->flags = tensor->src[0]->flags; + tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2]); + } else if (tensor->op == GGML_OP_ADD_ID) { + tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SSM_SCAN) { + tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], + src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_SSM_CONV) { + tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ROLL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t s3 = tensor->op_params[3]; + tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3); } - } else if (tensor->op == GGML_OP_UNARY) { - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_EXP: - tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_SILU: - tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_GELU: - tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_GELU_ERF: - tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_GELU_QUICK: - tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_RELU: - tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_TANH: - tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_SIGMOID: - tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_HARDSIGMOID: - tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); - break; - case GGML_UNARY_OP_HARDSWISH: - tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); - break; - default: + else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } - } else if (tensor->op == GGML_OP_GLU) { - if (src_clone[1] == nullptr) { - tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); - } else { - tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); - } - ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); - ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); - } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { - if (src1 == nullptr) { - tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); - tensor_clone->type = tensor->type; - } else { - tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); - } - } else if (tensor->op == GGML_OP_CONT) { - tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); - } else if (tensor->op == GGML_OP_RESHAPE) { - tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); - } else if (tensor->op == GGML_OP_VIEW) { - tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); - } else if (tensor->op == GGML_OP_PERMUTE) { - int32_t * params = (int32_t *)tensor->op_params; - tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); - } else if (tensor->op == GGML_OP_TRANSPOSE) { - tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_GET_ROWS) { - tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_ARGSORT) { - tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); - } else if (tensor->op == GGML_OP_SUM) { - tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_SUM_ROWS) { - tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_MEAN) { - tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_ARGMAX) { - tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); - } else if (tensor->op == GGML_OP_COUNT_EQUAL) { - tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); - } else if (tensor->op == GGML_OP_IM2COL) { - const int32_t s0 = tensor->op_params[0]; - const int32_t s1 = tensor->op_params[1]; - const int32_t p0 = tensor->op_params[2]; - const int32_t p1 = tensor->op_params[3]; - const int32_t d0 = tensor->op_params[4]; - const int32_t d1 = tensor->op_params[5]; - - const bool is_2D = tensor->op_params[6] == 1; - tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); - } else if (tensor->op == GGML_OP_IM2COL_3D) { - const int32_t s0 = tensor->op_params[0]; - const int32_t s1 = tensor->op_params[1]; - const int32_t s2 = tensor->op_params[2]; - const int32_t p0 = tensor->op_params[3]; - const int32_t p1 = tensor->op_params[4]; - const int32_t p2 = tensor->op_params[5]; - const int32_t d0 = tensor->op_params[6]; - const int32_t d1 = tensor->op_params[7]; - const int32_t d2 = tensor->op_params[8]; - const int32_t IC = tensor->op_params[9]; - - tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); - } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { - const int32_t dim = tensor->op_params[0]; - const int32_t max_period = tensor->op_params[1]; - tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); - } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ - const int32_t s0 = tensor->op_params[0]; - const int32_t p0 = tensor->op_params[1]; - const int32_t d0 = tensor->op_params[2]; - tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); - } else if (tensor->op == GGML_OP_POOL_2D) { - enum ggml_op_pool op = static_cast(tensor->op_params[0]); - const int32_t k0 = tensor->op_params[1]; - const int32_t k1 = tensor->op_params[2]; - const int32_t s0 = tensor->op_params[3]; - const int32_t s1 = tensor->op_params[4]; - const int32_t p0 = tensor->op_params[5]; - const int32_t p1 = tensor->op_params[6]; - - tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); - } else if (tensor->op == GGML_OP_CONV_2D) { - const int32_t s0 = tensor->op_params[0]; - const int32_t s1 = tensor->op_params[1]; - const int32_t p0 = tensor->op_params[2]; - const int32_t p1 = tensor->op_params[3]; - const int32_t d0 = tensor->op_params[4]; - const int32_t d1 = tensor->op_params[5]; - tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); - } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { - const int32_t s = tensor->op_params[0]; - tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); - } else if (tensor->op == GGML_OP_LEAKY_RELU) { - const float * op_params = (const float *)tensor->op_params; - tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); - } else if (tensor->op == GGML_OP_RWKV_WKV6) { - tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4], src_clone[5]); - } else if (tensor->op == GGML_OP_RWKV_WKV7) { - tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], - src_clone[4], src_clone[5], src_clone[6]); - } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { - src_clone[0]->flags = src0->flags; - tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4]); - } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { - src_clone[0]->flags = src0->flags; - tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2]); - } else if (tensor->op == GGML_OP_ADD_ID) { - tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); - } else if (tensor->op == GGML_OP_SSM_SCAN) { - tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], - src_clone[3], src_clone[4], src_clone[5], src_clone[6]); - } else if (tensor->op == GGML_OP_SSM_CONV) { - tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]); - } - else { - std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; - GGML_ABORT("fatal error"); + cloned_tensors[tensor] = tensor_clone; } ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); @@ -14506,10 +14372,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * memcpy(comp_result, tensor_clone->data, comp_size); memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (src_buffer[i] != nullptr) { - free(src_buffer[i]); - } + for (auto m : cloned_mallocs) { + free(m); } ggml_free(ggml_ctx); @@ -14518,15 +14382,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { - ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops]; if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { return; } - if (ctx->num_additional_fused_ops == 1 && - tensor->op == GGML_OP_RMS_NORM && - cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { - tensor = cgraph->nodes[tensor_idx + 1]; - } if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 99595fc68..c1ad51725 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -3,6 +3,9 @@ #include "rte.glsl" #include "utils.glsl" +#if RMS_NORM_ROPE_FUSION +#include "rope_params.glsl" +#endif layout (push_constant) uniform parameter { @@ -12,11 +15,16 @@ layout (push_constant) uniform parameter uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; uint misalign_offsets; float param1; float param2; int param3; +#if RMS_NORM_ROPE_FUSION + rope_params rope; +#endif } p; +#if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#endif // true if src0/src1 are the same shape and the indices can be reused without additional modulus layout(constant_id = 0) const bool norepeat = false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index d260969f0..5c5251da3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -100,7 +100,6 @@ layout (push_constant) uniform parameter layout (constant_id = 0) const uint BLOCK_SIZE = 64; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; -layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 4) const uint WM = 32; layout (constant_id = 5) const uint WN = 32; layout (constant_id = 6) const uint WMITER = 2; @@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat layout (constant_id = 10) const uint WARP = 32; +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#define BK 32 +#define BK_STEP 4 +#else +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +#define BK_STEP 2 +#endif + #ifdef COOPMAT #define SHMEM_STRIDE (BK / 2 + 4) #else @@ -244,8 +251,13 @@ void main() { } #else ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; +#if defined(DATA_A_F32) || defined(DATA_A_F16) + FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC4 cache_b; +#else FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; FLOAT_TYPE_VEC2 cache_b; +#endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); @@ -283,24 +295,41 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < BK / 2; i++) { + [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ]; + cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1]; + #else cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + #endif } } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ]; + cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1]; + #else cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + #endif [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + #if defined(DATA_A_F32) || defined(DATA_A_F16) + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); + #else sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); + #endif } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index d5b211ffa..3a47949d5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -3,6 +3,32 @@ #include "generic_binary_head.glsl" #include "types.glsl" +#if RMS_NORM_ROPE_FUSION + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; + +// data is passed from rms_norm -> rope through shared memory. +// rms_norm calls this data_d, rope calls this rope_data_a. +// Binding 2 is not used +shared FLOAT_TYPE rope_data_a[1024]; +#define data_d rope_data_a + +layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];}; +layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];}; +layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];}; +layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows + +#include "rope_params.glsl" +#include "rope_funcs.glsl" + +#define GGML_ROPE_TYPE_NORMAL 0 +#define GGML_ROPE_TYPE_NEOX 2 +#define GGML_ROPE_TYPE_MROPE 8 +#define GGML_ROPE_TYPE_VISION 24 + +#endif + #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 @@ -28,8 +54,12 @@ void rms_norm(uint num_iters) { uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); +#if RMS_NORM_ROPE_FUSION + // Per-row offset in shared memory + uint32_t d_offset = 0; +#else uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); - +#endif FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { @@ -79,6 +109,18 @@ void rms_norm(uint num_iters) { data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } +#if RMS_NORM_ROPE_FUSION + barrier(); + rope_params rp = p.rope; + uint rope_row = (samp*nchannels + channel)*nrows + row; + for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { + if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { + rope_neox(t, rope_row, rp); + } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { + rope_norm(t, rope_row, rp); + } + } +#endif } void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl new file mode 100644 index 000000000..9726b722d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -0,0 +1,227 @@ + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +#if RMS_NORM_ROPE_FUSION + // Per-row offset in shared memory + const uint ix = i0; +#else + const uint ix = i02*p.nb02 + i01*p.nb01 + i0; +#endif + return ix; +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} + +void rope_norm(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + uint idst = i1*ne0 + i0; + const uint ix = rope_a_coord(i0, i01, i02, p); + + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); + rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + 1]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_neox(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0/2; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + + +void rope_multi(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (p.is_imrope != 0) { + if (sector % 3 == 1 && sector < 3 * p.sections[1]) { + theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } else { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } else { + if (sector < p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_vision(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index fa2bb3339..d9b4d4c03 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -3,56 +3,18 @@ #extension GL_EXT_shader_16bit_storage : require #include "rte.glsl" +#include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_pos[];}; -layout (binding = 2) readonly buffer Z {float data_ff[];}; -layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; -layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows +layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];}; +layout (binding = 1) readonly buffer Y {int rope_data_pos[];}; +layout (binding = 2) readonly buffer Z {float rope_data_ff[];}; +layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];}; +layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows + layout (push_constant) uniform parameter { - uint ncols; - uint n_dims; - float freq_scale; - uint p_delta_rows; - float freq_base; - float ext_factor; - float attn_factor; - float corr_dims[2]; - float theta_scale; - uint has_ff; - uint ne02; - uint s1; - uint s2; - int sections[4]; - uint is_imrope; - uint is_back; - uint set_rows_stride; -} p; + rope_params pc; +}; -float rope_yarn_ramp(const float low, const float high, const uint i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { - float mscale = p.attn_factor; - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = p.freq_scale * theta_extrap; - float theta = theta_interp; - if (p.ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); - } - // Backprogagation uses inverted rotation - if (p.is_back != 0) { - theta = -theta; - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 54aabcf22..7c1fb1cd2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,70 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; - data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; - - return; - } - - const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (p.is_imrope != 0) { - if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } else { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } else { - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_multi(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 9f4538155..68f00c180 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0/2; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]); - data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_neox(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index f4209ed95..28a939ec6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + 0] = D_TYPE(data_a[ix + 0]); - data_d[idst + 1] = D_TYPE(data_a[ix + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + 1]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_norm(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl new file mode 100644 index 000000000..82f39cee3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -0,0 +1,27 @@ +#if !defined(GGML_ROPE_PARAMS) +#define GGML_ROPE_PARAMS + +#include "rte.glsl" + +struct rope_params { + uint rope_mode; + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint nb01; + uint nb02; + int sections[4]; + uint is_imrope; + uint is_back; + uint set_rows_stride; +}; + +#endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d37d1c104..ea1e0fdb4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,47 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - const int sect_dims = p.sections[0] + p.sections[1]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (sector < p.sections[0]) { - const uint p0 = sector; - theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); - } - else if (sector >= p.sections[0] && sector < sec_w) { - const uint p0 = sector - p.sections[0]; - theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); - } - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_vision(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fc4560fba..f2f5e966b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -713,6 +713,8 @@ void process_shaders() { string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -858,25 +860,25 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index b819e6169..d6b1e5644 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -2473,10 +2473,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in set_clip_uses_gpu(false); printf("Clip forced to use CPU!\n"); } - clip_init_result cres = clip_init(mmproj_filename.c_str(), clip_context_params{ - /* use_gpu */ true, - /* verbosity */ static_cast(1), - }); + clip_context_params ctx_clip_params { + /* use_gpu */ true, + /* verbosity */ static_cast(1), + /* flash_attn_type */ (kcpp_data->flash_attn?CLIP_FLASH_ATTN_TYPE_ENABLED:CLIP_FLASH_ATTN_TYPE_DISABLED), + /* image_min_tokens */ -1, + /* image_max_tokens */ -1, + }; + clip_init_result cres = clip_init(mmproj_filename.c_str(), ctx_clip_params); clp_ctx_v = cres.ctx_v; clp_ctx_a = cres.ctx_a; if(clp_ctx_v == nullptr && clp_ctx_a == nullptr) { diff --git a/include/llama.h b/include/llama.h index e7bce1151..12666cfd3 100644 --- a/include/llama.h +++ b/include/llama.h @@ -466,6 +466,7 @@ extern "C" { // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions // In some cases the requested values via llama_context_params may differ from the actual values used by the context + // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); @@ -488,6 +489,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4813f6fd1..633e75631 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -21,6 +21,8 @@ llama_context::llama_context( llama_context_params params) : model(model), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { + // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, + // may need to be backend-dependent LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); t_start_us = model.t_start_us; @@ -112,10 +114,14 @@ llama_context::llama_context( } } + // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); + if (cparams.kv_unified) { cparams.n_ctx_seq = cparams.n_ctx; } else { cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; + cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); if (cparams.n_ctx_seq == 0) { throw std::runtime_error("n_ctx_seq == 0"); @@ -823,7 +829,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -992,7 +998,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; const int64_t n_vocab = vocab.n_tokens(); - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd_inp(); // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; @@ -2150,7 +2156,7 @@ void llama_context::opt_epoch_iter( batch.logits [pos_batch] = true; } - if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f9751b318..b199e9462 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1142,7 +1142,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd_inp(); auto inp = std::make_unique(); @@ -1279,7 +1279,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { // return cur; //} - const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd; + const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 514d65384..8cdbaf69f 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -60,6 +60,16 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { return n_head/n_head_kv; } +uint32_t llama_hparams::n_embd_inp() const { + uint32_t n_embd_inp = n_embd; + + if (n_deepstack_layers > 0) { + n_embd_inp += n_embd * n_deepstack_layers; + } + + return n_embd_inp; +} + uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 539fecb3f..9203af83b 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -227,6 +227,9 @@ struct llama_hparams { uint32_t n_gqa(uint32_t il = 0) const; + // dimension of main + auxiliary input embeddings + uint32_t n_embd_inp() const; + // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 88eb96039..885e77fe1 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad)); + // note: the SWA cache is always padded to 256 for performance + // https://github.com/ggml-org/llama.cpp/issues/17037 + uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256); //kcpp: pad the swa kv cache as well, similar to extra_context_handle_fragmentation size_swa += 128; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 467b4de7a..566dfcb2c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -121,6 +121,7 @@ #include "models/wavtokenizer-dec.cpp" #include "models/xverse.cpp" #include "models/graph-context-mamba.cpp" +#include "models/pangu-embedded.cpp" #if defined(GGML_USE_CLBLAST) # include "ggml_v3b-opencl.h" @@ -376,8 +377,8 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_IM2COL: { - const int n_embd = hparams.n_embd; - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + const int n_embd_inp = hparams.n_embd_inp(); + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); } break; case GGML_OP_SCALE: @@ -1139,9 +1140,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; } - // since vision model stacks deepstack features along feature dim - // we also create a fake "n_embd" for text model to be the main embd + deepstack embds - hparams.n_embd *= hparams.n_deepstack_layers + 1; } break; case LLM_ARCH_QWEN3MOE: { @@ -1165,9 +1163,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } - // since vision model stacks deepstack features along feature dim - // we also create a fake "n_embd" for text model to be the main embd + deepstack embds - hparams.n_embd *= hparams.n_deepstack_layers + 1; } break; case LLM_ARCH_PHI2: { @@ -3494,10 +3489,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3VL: { - // for model loading, the weights only have the main embd - // so we need to divide by the number of deepstack layers + 1 - // n_embd is const int so we declare a new variable - int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3533,10 +3524,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_QWEN3MOE: case LLM_ARCH_QWEN3VLMOE: { - // for model loading, the weights only have the main embd - // so we need to divide by the number of deepstack layers + 1 - // n_embd is const int so we declare a new variable - int64_t n_embd = hparams.n_embd / (hparams.n_deepstack_layers + 1); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -6689,6 +6676,7 @@ void llama_model::print_info() const { if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); @@ -7537,6 +7525,10 @@ int32_t llama_model_n_embd(const llama_model * model) { return model->hparams.n_embd; } +int32_t llama_model_n_embd_inp(const llama_model * model) { + return model->hparams.n_embd_inp(); +} + int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer; } diff --git a/src/models/qwen3vl-moe.cpp b/src/models/qwen3vl-moe.cpp index c48643c0c..f72f80a83 100644 --- a/src/models/qwen3vl-moe.cpp +++ b/src/models/qwen3vl-moe.cpp @@ -1,9 +1,8 @@ #include "models.h" llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1); + const int64_t n_embd = hparams.n_embd; const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/qwen3vl.cpp b/src/models/qwen3vl.cpp index 10b36c1f6..0bae52239 100644 --- a/src/models/qwen3vl.cpp +++ b/src/models/qwen3vl.cpp @@ -1,13 +1,10 @@ #include "models.h" llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - - const int64_t n_embd_full = hparams.n_embd; // main embd + deepstack embds const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = n_embd_full / (n_deepstack_layers + 1); + const int64_t n_embd = hparams.n_embd; const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 38a6827cd..734b9e15e 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1110,16 +1110,24 @@ struct clip_graph { } ggml_cgraph * build_minicpmv() { - const int batch_size = 1; - GGML_ASSERT(model.class_embedding == nullptr); - const int n_pos = n_patches; + const int n_pos = n_patches; + const int n_embd_proj = clip_n_mmproj_embd(ctx); // position embeddings for the projector (not for ViT) - int n_output_dim = clip_n_mmproj_embd(ctx); - ggml_tensor * pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, n_pos, batch_size); - ggml_set_name(pos_embed, "pos_embed"); - ggml_set_input(pos_embed); + // see: https://huggingface.co/openbmb/MiniCPM-o-2_6/blob/main/resampler.py#L70 + // base frequency omega + ggml_tensor * omega = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_proj / 4); + ggml_set_name(omega, "omega"); + ggml_set_input(omega); + + // 2D input positions (using float for sinusoidal embeddings) + ggml_tensor * pos_h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + ggml_tensor * pos_w = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); // for selecting learned pos embd, used by ViT struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); @@ -1130,7 +1138,7 @@ struct clip_graph { ggml_tensor * inp = build_inp(); ggml_tensor * embeddings = build_vit( - inp, n_patches, + inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, learned_pos_embd, @@ -1142,17 +1150,39 @@ struct clip_graph { ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); // norm - q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1); + q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1); v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1); + // calculate sinusoidal pos embd + ggml_tensor * pos_embed = nullptr; + { + // outer product + ggml_tensor * omega_b = ggml_repeat_4d(ctx0, omega, omega->ne[0], n_pos, 1, 1); // n_pos rows + ggml_tensor * theta_x = ggml_mul(ctx0, omega_b, pos_w); + ggml_tensor * theta_y = ggml_mul(ctx0, omega_b, pos_h); + // sin and cos + ggml_tensor * pos_embd_x = ggml_concat( + ctx0, + ggml_sin(ctx0, theta_x), + ggml_cos(ctx0, theta_x), + 0 // concat on first dim + ); + ggml_tensor * pos_embd_y = ggml_concat( + ctx0, + ggml_sin(ctx0, theta_y), + ggml_cos(ctx0, theta_y), + 0 // concat on first dim + ); + pos_embed = ggml_concat(ctx0, pos_embd_x, pos_embd_y, 0); + } + // k = v + pos_embed ggml_tensor * k = ggml_add(ctx0, v, pos_embed); // attention { - int n_embd = clip_n_mmproj_embd(ctx); const int d_head = 128; - int n_head = n_embd/d_head; + int n_head = n_embd_proj/d_head; // Use actual config value if available, otherwise fall back to hardcoded values int num_query = ctx->model.hparams.minicpmv_query_num; ggml_tensor * Q = ggml_add(ctx0, @@ -2842,17 +2872,11 @@ struct clip_model_loader { get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL); // only 2.5 requires it // ref: https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json - // the actual max limit is 12845056/14/14/2/2/4 = 4096 tokens - // but we set a lower value to avoid OOM - // TODO: make it configurable by user - // TODO (2): bbox coordinates become inaccurate with small number of tokens, - // therefore we need to increase the min_tokens - // see: https://github.com/ggml-org/llama.cpp/issues/16842#issuecomment-3475144858 - if (model.proj_type==PROJECTOR_TYPE_QWEN25VL && q25vl_migrated) { + if (model.proj_type==PROJECTOR_TYPE_QWEN25VL && q25vl_migrated) { hparams.n_wa_pattern = 8; } - hparams.set_limit_image_tokens(8, 2048); - hparams.set_warmup_n_tokens(256); // avoid OOM on warmup + hparams.set_limit_image_tokens(8, 4096); + hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup const int warn_min_pixels = 1024 * hparams.n_merge * hparams.n_merge * hparams.patch_size * hparams.patch_size; if (hparams.image_min_pixels < warn_min_pixels) { LOG_WRN("%s: Qwen-VL models require at minimum 1024 image tokens to function correctly on grounding tasks\n", __func__); @@ -4749,92 +4773,6 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im return n_patches; } -static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { - assert(embed_dim % 2 == 0); - int H = pos.size(); - int W = pos[0].size(); - - std::vector omega(embed_dim / 2); - for (int i = 0; i < embed_dim / 2; ++i) { - omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); - } - - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - float out_value = pos[h][w] * omega[d]; - emb[h][w][d] = sin(out_value); - emb[h][w][d + embed_dim / 2] = cos(out_value); - } - } - } - - return emb; -} - -static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>> & grid) { - assert(embed_dim % 2 == 0); - std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) - std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) - - int H = emb_h.size(); - int W = emb_h[0].size(); - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - emb[h][w][d] = emb_h[h][w][d]; - emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; - } - } - } - return emb; -} - -static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { - int grid_h_size = image_size.first; - int grid_w_size = image_size.second; - - std::vector grid_h(grid_h_size); - std::vector grid_w(grid_w_size); - - for (int i = 0; i < grid_h_size; ++i) { - grid_h[i] = static_cast(i); - } - for (int i = 0; i < grid_w_size; ++i) { - grid_w[i] = static_cast(i); - } - - std::vector> grid(grid_h_size, std::vector(grid_w_size)); - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid[h][w] = grid_w[w]; - } - } - std::vector>> grid_2d = {grid, grid}; - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid_2d[0][h][w] = grid_h[h]; - grid_2d[1][h][w] = grid_w[w]; - } - } - - std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); - - int H = image_size.first; - int W = image_size.second; - std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; - } - } - - return pos_embed_2d; -} - bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { clip_image_f32_batch imgs; clip_image_f32_ptr img_copy(clip_image_f32_init()); @@ -4973,27 +4911,33 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } set_input_i32("positions", positions); - // inspired from resampler of Qwen-VL: - // -> https://huggingface.co/Qwen/Qwen-VL/tree/main - // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 - int embed_dim = clip_n_mmproj_embd(ctx); - - // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? - auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); - - std::vector pos_embed(embed_dim * pos_w * pos_h); - for(int i = 0; i < pos_w * pos_h; ++i){ - for(int j = 0; j < embed_dim; ++j){ - pos_embed[i * embed_dim + j] = pos_embed_t[i][j]; - } + // inputs for resampler projector + // set the 2D positions (using float for sinusoidal embedding) + int n_patches_per_col = image_size_width / patch_size; + std::vector pos_data(n_pos); + // dimension H + for (int i = 0; i < n_pos; i++) { + pos_data[i] = static_cast(i / n_patches_per_col); } - - set_input_f32("pos_embed", pos_embed); + set_input_f32("pos_h", pos_data); + // dimension W + for (int i = 0; i < n_pos; i++) { + pos_data[i] = static_cast(i % n_patches_per_col); + } + set_input_f32("pos_w", pos_data); + // base frequency omega + const float base_freq = 10000.0f; + const int n_embd_proj = clip_n_mmproj_embd(ctx); + std::vector omega(n_embd_proj / 4); + for (int i = 0; i < n_embd_proj / 4; ++i) { + omega[i] = 1.0f / std::pow(base_freq, static_cast(i) / (n_embd_proj / 4)); + } + set_input_f32("omega", omega); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN3VL: { - const int merge_ratio = 2; + const int merge_ratio = hparams.n_merge; const int pw = image_size_width / patch_size; const int ph = image_size_height / patch_size; std::vector positions(n_pos * 4); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4343f3b6f..e59913776 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -163,7 +163,7 @@ struct mtmd_context { print_timings(ctx_params.print_timings), n_threads (ctx_params.n_threads), media_marker (ctx_params.media_marker), - n_embd_text (llama_model_n_embd(text_model)) + n_embd_text (llama_model_n_embd_inp(text_model)) { if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) { throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead"); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 678aad93b..9d91e32d1 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2400,7 +2400,7 @@ struct server_context { add_bos_token = llama_vocab_get_add_bos(vocab); - if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { + if (params_base.has_speculative()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); auto params_dft = params_base; @@ -2476,7 +2476,7 @@ struct server_context { SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); } - if (!params_base.speculative.model.path.empty()) { + if (params_base.has_speculative()) { SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); return false; } @@ -2520,6 +2520,7 @@ struct server_context { if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { SRV_ERR("%s", "failed to create draft context\n"); @@ -2822,9 +2823,12 @@ struct server_context { send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } + + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); } // initialize draft batch + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] if (slot.ctx_dft) { llama_batch_free(slot.batch_spec); @@ -3074,7 +3078,7 @@ struct server_context { res->progress.total = slot.task->n_tokens(); res->progress.cache = slot.n_prompt_tokens_cache; res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; } else { res->content = tkn.text_to_send; res->tokens = { tkn.tok }; @@ -3830,7 +3834,9 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, n_past - n_swa); - if (n_past > 0 && n_past < slot.prompt.n_tokens()) { + // note: disallow with mtmd contexts for now + // https://github.com/ggml-org/llama.cpp/issues/17043 + if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); @@ -4291,6 +4297,8 @@ struct server_context { } // do speculative decoding + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch for (auto & slot : slots) { if (!slot.is_processing() || !slot.can_speculate()) { continue; @@ -4445,8 +4453,10 @@ int main(int argc, char ** argv) { // TODO: should we have a separate n_parallel parameter for the server? // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - if (params.n_parallel == 1 && params.kv_unified == false) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__); + // TODO: this is a common configuration that is suitable for most local use cases + // however, overriding the parameters is a bit confusing - figure out something more intuitive + if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { + LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); params.n_parallel = 4; params.kv_unified = true; diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index 65952de8b..d2f3fba5f 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -77,10 +77,10 @@ def test_different_draft_min_draft_max(): def test_slot_ctx_not_exceeded(): global server - server.n_ctx = 64 + server.n_ctx = 256 server.start() res = server.make_request("POST", "/completion", data={ - "prompt": "Hello " * 56, + "prompt": "Hello " * 248, "temperature": 0.0, "top_k": 1, "speculative.p_min": 0.0, @@ -91,19 +91,19 @@ def test_slot_ctx_not_exceeded(): def test_with_ctx_shift(): global server - server.n_ctx = 64 + server.n_ctx = 256 server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ - "prompt": "Hello " * 56, + "prompt": "Hello " * 248, "temperature": 0.0, "top_k": 1, - "n_predict": 64, + "n_predict": 256, "speculative.p_min": 0.0, }) assert res.status_code == 200 assert len(res.body["content"]) > 0 - assert res.body["tokens_predicted"] == 64 + assert res.body["tokens_predicted"] == 256 assert res.body["truncated"] == True