diff --git a/common/arg.cpp b/common/arg.cpp index c7440728c..558c9697a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -4,7 +4,6 @@ #include "chat.h" #include "common.h" #include "download.h" -#include "hf-cache.h" #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" @@ -539,7 +538,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); } if (!seen_args.insert(arg).second) { - LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str()); + const bool skip = (arg == "--spec-type"); + + if (!skip) { + LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str()); + } } auto & tmp = arg_to_options[arg]; auto opt = *tmp.first; @@ -588,12 +591,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context // parse the first time to get -hf option (used for remote preset) parse_cli_args(); - // TODO: Remove later // KCPP: remove for now - // try { - // hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline); - // } catch (const std::exception & e) { - // LOG_WRN("HF cache migration failed: %s\n", e.what()); - // } // export_graph_ops loads only metadata const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; @@ -902,7 +899,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map samplers_seq_config; }; -common_init_result::common_init_result(common_params & params) : +common_init_result::common_init_result(common_params & params, bool model_only) : pimpl(new impl{}) { auto mparams = common_model_params_to_llama(params); auto cparams = common_context_params_to_llama(params); @@ -1179,7 +1179,7 @@ common_init_result::common_init_result(common_params & params) : params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, - params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); + params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); @@ -1189,6 +1189,10 @@ common_init_result::common_init_result(common_params & params) : pimpl->model.reset(model); + if (model_only) { + return; + } + const llama_vocab * vocab = llama_model_get_vocab(model); // load and optionally apply lora adapters @@ -1258,29 +1262,6 @@ common_init_result::common_init_result(common_params & params) : cparams.n_samplers = pimpl->samplers_seq_config.size(); } - // [TAG_RS_STATE_ROLLBACK_SUPPORT] - // TODO: ngram speculative methods require checkpointing in addition to partial RS rollback - // currently this is not supported. so we disable the partial rollback - if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) { - auto & types = params.speculative.types; - - for (int i = 0; i < (int) types.size(); i++) { - if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) { - continue; - } - if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) { - continue; - } - - cparams.n_rs_seq = 0; - - LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__, - common_speculative_type_to_str(types[i]).c_str()); - - break; - } - } - llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); @@ -1315,8 +1296,8 @@ std::vector & common_init_result::lora() { return pimpl->lora; } -common_init_result_ptr common_init_from_params(common_params & params) { - common_init_result_ptr res(new common_init_result(params)); +common_init_result_ptr common_init_from_params(common_params & params, bool model_only) { + common_init_result_ptr res(new common_init_result(params, model_only)); llama_model * model = res->model(); if (model == NULL) { @@ -1324,6 +1305,10 @@ common_init_result_ptr common_init_from_params(common_params & params) { return res; } + if (model_only) { + return res; + } + llama_context * lctx = res->context(); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); @@ -1387,7 +1372,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { } if (params.warmup) { - LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); + LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); llama_set_warmup(lctx, true); diff --git a/common/common.h b/common/common.h index e0f6b6780..23bc9d631 100644 --- a/common/common.h +++ b/common/common.h @@ -300,11 +300,11 @@ struct common_params_model { // draft-model-based speculative decoding parameters struct common_params_speculative_draft { - int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding - float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.0f; // minimum speculative decoding probability (greedy) common_params_model mparams; @@ -858,7 +858,7 @@ struct common_sampler; // note: defines the model, context, samplers, ets. lifetimes struct common_init_result { - common_init_result(common_params & params); + common_init_result(common_params & params, bool model_only = false); ~common_init_result(); llama_model * model(); @@ -876,7 +876,7 @@ private: using common_init_result_ptr = std::unique_ptr; -common_init_result_ptr common_init_from_params(common_params & params); +common_init_result_ptr common_init_from_params(common_params & params, bool model_only = false); struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params); diff --git a/common/hf-cache.cpp b/common/hf-cache.cpp index 20f33e4c7..ba7417a12 100644 --- a/common/hf-cache.cpp +++ b/common/hf-cache.cpp @@ -11,7 +11,6 @@ #include #include #include -#include // migration only #include #include #include @@ -336,15 +335,9 @@ hf_files get_repo_files(const std::string & repo_id, if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) { file.oid = item["lfs"]["oid"].get(); } - if (item["lfs"].contains("size") && item["lfs"]["size"].is_number()) { - file.size = item["lfs"]["size"].get(); - } } else if (item.contains("oid") && item["oid"].is_string()) { file.oid = item["oid"].get(); } - if (file.size == 0 && item.contains("size") && item["size"].is_number()) { - file.size = item["size"].get(); - } if (!file.oid.empty() && !is_valid_oid(file.oid)) { LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str()); @@ -502,271 +495,4 @@ std::string finalize_file(const hf_file & file) { return file.final_path; } -// delete everything after this line, one day - -// copied from download.cpp without the tag part -struct gguf_split_info { - std::string prefix; // tag included - int index; - int count; -}; - -static gguf_split_info get_gguf_split_info(const std::string & path) { - static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase); - std::smatch m; - - std::string prefix = path; - if (!string_remove_suffix(prefix, ".gguf")) { - return {}; - } - - int index = 1; - int count = 1; - - if (std::regex_match(prefix, m, re_split)) { - index = std::stoi(m[2].str()); - count = std::stoi(m[3].str()); - prefix = m[1].str(); - } - - return {std::move(prefix), index, count}; -} - -static std::pair parse_manifest_name(std::string & filename) { - static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)"); - std::smatch match; - if (std::regex_match(filename, match, re)) { - return {match[1].str(), match[2].str()}; - } - return {}; -} - -static std::string make_old_cache_filename(const std::string & owner, - const std::string & repo, - const std::string & filename) { - auto result = owner + "_" + repo + "_" + filename; - string_replace_all(result, "/", "_"); - return result; -} - -struct migrate_file { - std::string path; - std::string sha256; - size_t size; - fs::path old_path; - fs::path etag_path; - const hf_file * file; -}; - -using migrate_files = std::vector; - -static bool collect_file(const fs::path & old_cache, - const std::string & owner, - const std::string & repo, - const std::string & path, - const std::string & sha256, - const hf_files & files, - migrate_files & to_migrate) { - - const hf_file * file = nullptr; - - for (const auto & f : files) { - if (f.path == path) { - file = &f; - break; - } - } - - std::string old_filename = make_old_cache_filename(owner, repo, path); - fs::path old_path = old_cache / old_filename; - fs::path etag_path = old_path.string() + ".etag"; - - if (!fs::exists(old_path)) { - if (file && fs::exists(file->final_path)) { - return true; - } - LOG_WRN("%s: %s not found in old cache or HF cache\n", __func__, old_filename.c_str()); - return false; - } - - if (!file) { - LOG_WRN("%s: %s not found in current repo\n", __func__, old_filename.c_str()); - return false; - } - - if (!sha256.empty() && !file->oid.empty() && sha256 != file->oid) { - LOG_WRN("%s: %s is not up to date (sha256 mismatch)\n", __func__, old_filename.c_str()); - return false; - } - - if (file->size > 0) { - size_t size = fs::file_size(old_path); - if (size != file->size) { - LOG_WRN("%s: %s has wrong size %zu (expected %zu)\n", __func__, old_filename.c_str(), size, file->size); - return false; - } - } - - to_migrate.push_back({path, sha256, file->size, old_path, etag_path, file}); - return true; -} - -static bool collect_files(const fs::path & old_cache, - const std::string & owner, - const std::string & repo, - const nl::json & node, - const hf_files & files, - migrate_files & to_migrate) { - - if (!node.contains("rfilename") || - !node.contains("lfs") || - !node["lfs"].contains("sha256")) { - return true; - } - - std::string path = node["rfilename"]; - std::string sha256 = node["lfs"]["sha256"]; - - auto split = get_gguf_split_info(path); - - if (split.count <= 1) { - return collect_file(old_cache, owner, repo, path, sha256, files, to_migrate); - } - - std::vector> splits; - - for (const auto & f : files) { - auto split_f = get_gguf_split_info(f.path); - if (split_f.count == split.count && split_f.prefix == split.prefix) { - // sadly the manifest only provides the sha256 of the first file (index == 1) - // the rest will be verified using the size... - std::string f_sha256 = (split_f.index == 1) ? sha256 : ""; - splits.emplace_back(f.path, f_sha256); - } - } - - if ((int)splits.size() != split.count) { - LOG_WRN("%s: expected %d split files but found %d in repo\n", __func__, split.count, (int)splits.size()); - return false; - } - - for (const auto & [f_path, f_sha256] : splits) { - if (!collect_file(old_cache, owner, repo, f_path, f_sha256, files, to_migrate)) { - return false; - } - } - - return true; -} - -static bool migrate_file(const migrate_file & file) { - std::error_code ec; - - fs::path new_path(file.file->local_path); - fs::create_directories(new_path.parent_path(), ec); - - if (!fs::exists(new_path, ec)) { - fs::rename(file.old_path, new_path, ec); - if (ec) { - fs::copy_file(file.old_path, new_path, ec); - if (ec) { - LOG_ERR("%s: failed to move/copy %s: %s\n", __func__, file.old_path.string().c_str(), ec.message().c_str()); - return false; - } - } - fs::remove(file.old_path, ec); - } - fs::remove(file.etag_path, ec); - - std::string filename = finalize_file(*file.file); - LOG_INF("%s: migrated %s -> %s\n", __func__, file.old_path.filename().string().c_str(), filename.c_str()); - return true; -} - -void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) { - fs::path old_cache = fs_get_cache_directory(); - if (!fs::exists(old_cache)) { - return; - } - - if (offline) { - LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__); - return; // -hf is not going to work - } - - bool warned = false; - - for (const auto & entry : fs::directory_iterator(old_cache)) { - if (!entry.is_regular_file()) { - continue; - } - auto filename = entry.path().filename().string(); - auto [owner, repo] = parse_manifest_name(filename); - - if (owner.empty() || repo.empty()) { - continue; - } - - if (!warned) { - warned = true; - LOG_WRN("================================================================================\n" - "WARNING: Migrating cache to HuggingFace cache directory\n" - " Old cache: %s\n" - " New cache: %s\n" - "This one-time migration moves models previously downloaded with -hf\n" - "from the legacy llama.cpp cache to the standard HuggingFace cache.\n" - "Models downloaded with --model-url are not affected.\n" - "================================================================================\n", - old_cache.string().c_str(), get_cache_directory().string().c_str()); - } - - auto repo_id = owner + "/" + repo; - auto files = get_repo_files(repo_id, token); - - if (files.empty()) { - LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str()); - continue; - } - - migrate_files to_migrate; - bool ok = true; - - try { - std::ifstream manifest(entry.path()); - auto json = nl::json::parse(manifest); - for (const char * key : {"ggufFile", "mmprojFile"}) { - if (json.contains(key)) { - if (!collect_files(old_cache, owner, repo, json[key], files, to_migrate)) { - ok = false; - break; - } - } - } - } catch (const std::exception & e) { - LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what()); - continue; - } - - if (!ok) { - LOG_WRN("%s: migration skipped: one or more files failed validation\n", __func__); - continue; - } - - for (const auto & file : to_migrate) { - if (!migrate_file(file)) { - ok = false; - break; - } - } - - if (!ok) { - LOG_WRN("%s: migration failed: could not migrate all files\n", __func__); - continue; - } - - LOG_INF("%s: migration complete, deleting manifest: %s\n", __func__, entry.path().string().c_str()); - fs::remove(entry.path()); - } -} - } // namespace hf_cache diff --git a/common/hf-cache.h b/common/hf-cache.h index 9e46f9774..23fa0adb7 100644 --- a/common/hf-cache.h +++ b/common/hf-cache.h @@ -14,7 +14,6 @@ struct hf_file { std::string final_path; std::string oid; std::string repo_id; - size_t size = 0; // only for the migration }; using hf_files = std::vector; @@ -30,7 +29,4 @@ hf_files get_cached_files(const std::string & repo_id = {}); // Create snapshot path (link or move/copy) and return it std::string finalize_file(const hf_file & file); -// TODO: Remove later -void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false); - } // namespace hf_cache diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index 02bc482fe..936415976 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map, draft.push_back(inp[match_pos + n + i]); } - LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__, + LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__, key_offset, slot_max, curr_key.key_num, draft.size()); diff --git a/common/speculative.cpp b/common/speculative.cpp index 37b58d8af..723db5cf5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -32,6 +32,19 @@ const std::map common_speculative_type_fro {"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} }; +static std::string common_speculative_get_devices_str(const std::vector & devices) { + if (devices.empty()) { + return "default"; + } + + std::string result; + for (size_t i = 0; i < devices.size(); i++) { + if (i > 0) result += ", "; + result += ggml_backend_dev_name(devices[i]); + } + return result; +} + struct common_speculative_config { common_speculative_type type; common_params_speculative params; @@ -144,7 +157,7 @@ struct common_speculative_impl { virtual void draft(common_speculative_draft_params_vec & dparams) = 0; - virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0; + virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0; // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; @@ -167,6 +180,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; auto * ctx_tgt = this->params.ctx_tgt; + LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min); + LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__, + this->params.n_gpu_layers, + ggml_type_name(this->params.cache_type_k), + ggml_type_name(this->params.cache_type_v), + ctx_tgt ? "yes" : "no", + ctx_dft ? "yes" : "no", + common_speculative_get_devices_str(this->params.devices).c_str()); + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); // TODO: optimize or pass from outside? @@ -343,7 +366,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { } } - void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } @@ -355,8 +378,12 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl { struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { //common_params_speculative_eagle3 params; - common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq) - : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {} + common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq) + : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) + { + LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min); + } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop @@ -371,7 +398,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { // TODO: implement } - void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } @@ -380,7 +407,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { } }; -struct common_speculative_state_draft_mtp : public common_speculative_impl { +struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) llama_batch batch; @@ -407,7 +434,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { // pre-advancement before process() mirrored the verify batch. std::vector last_n_drafted; - common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq) + common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) , params(params.draft) { @@ -417,6 +444,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); + LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd); + LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__, + this->params.n_gpu_layers, + ggml_type_name(this->params.cache_type_k), + ggml_type_name(this->params.cache_type_v), + ctx_tgt ? "yes" : "no", + ctx_dft ? "yes" : "no", + common_speculative_get_devices_str(this->params.devices).c_str()); + const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); // llama_batch_init allocates only one of token/embd; MTP needs both. @@ -427,7 +464,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; - sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param + sparams.top_k = 10; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } @@ -446,7 +483,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { last_n_drafted.assign(n_seq, 0); } - ~common_speculative_state_draft_mtp() override { + ~common_speculative_impl_draft_mtp() override { if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; @@ -462,7 +499,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); if (pos_max < N - 1) { - LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — " + LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - " "process() hook may not have run on every prefill ubatch " "(need_embd / logits=1 on every prompt position?). " "Drafts may degrade.\n", @@ -633,6 +670,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { // add drafted token for each sequence const llama_token id = cur_p->data[0].id; + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + drafting[seq_id] = false; + n_drafting--; + + continue; + } + common_sampler_accept(smpl, id, true); auto & dp = dparams.at(seq_id); @@ -678,7 +723,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl { } } - void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override { if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { return; } @@ -714,7 +759,12 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { common_ngram_simple_config config) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) , params(params.ngram_simple) - , config(config) {} + , config(config) + { + LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__); + LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__, + this->params.size_n, this->params.size_m, this->params.min_hits); + } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop @@ -738,7 +788,7 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { } } - void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } @@ -748,20 +798,21 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl { }; struct common_speculative_impl_ngram_map_k : public common_speculative_impl { - common_params_speculative_ngram_map params; - // n_seq configs std::vector config; common_speculative_impl_ngram_map_k( - const common_params_speculative & params, const common_ngram_map & config, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq) - , params(params.ngram_map_k) { + { for (uint32_t i = 0; i < n_seq; i++) { this->config.push_back(config); } + + LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str()); + LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__, + config.size_key, config.size_value, config.key_only, config.min_hits); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { @@ -788,9 +839,13 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl { } } - void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override { GGML_ASSERT((seq_id < (llama_seq_id) config.size())); + if (is_other) { + return; + } + common_ngram_map_accept(config[seq_id], n_accepted); } @@ -812,7 +867,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { // the last position in the prompt that was added to the ngram container size_t i_last = 0; - // length of the last drafted n‑gram (number of tokens returned by draft) + // length of the last drafted n-gram (number of tokens returned by draft) size_t n_draft_last = 0; // consecutive accept rounds with low acceptance fraction (< 0.5) @@ -830,8 +885,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); - LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__, - this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024); + LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__); + LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__, + this->params.n_match, this->params.n_max, this->params.n_min); + LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__, + mod.size(), (float)(mod.size_bytes())/1024/1024); if (this->params.n_match < 16) { LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " @@ -921,7 +979,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { } result.resize(result.size() - n); - // store length of drafted n‑gram for later acceptance analysis + // store length of drafted n-gram for later acceptance analysis sinfo.n_draft_last = result.size(); } @@ -943,17 +1001,21 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl { } } - void accept(llama_seq_id seq_id, uint16_t n_accepted) override { + void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override { + if (is_other) { + return; + } + auto & sinfo = sinfos[seq_id]; // compute acceptance fraction if we have a recorded draft length if (sinfo.n_draft_last > 0) { const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; - if (f_acc < 0.5) { + if (f_acc < 0.25) { sinfo.n_low++; - if (sinfo.n_low >= 3) { + if (sinfo.n_low >= 5) { if (verbose) { - LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low); + LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); @@ -1003,6 +1065,12 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl { , save_dynamic(save_dynamic) , save_static(save_static) { + LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__); + LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__, + n_draft, + path_static.empty() ? "none" : path_static.c_str(), + path_dynamic.empty() ? "none" : path_dynamic.c_str()); + sinfos.resize(n_seq); if (!path_static.empty()) { @@ -1099,7 +1167,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl { } } - void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override { + void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } @@ -1285,7 +1353,6 @@ common_speculative * common_speculative_init(common_params_speculative & params, std::vector> impls = {}; for (const common_speculative_config & config : configs) { - LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str()); switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; @@ -1298,7 +1365,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, break; } case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: { - impls.push_back(std::make_unique(config.params, n_seq)); + impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { @@ -1319,11 +1386,16 @@ common_speculative * common_speculative_init(common_params_speculative & params, impls.push_back(std::move(state)); break; } - case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: { + impls.push_back( + std::make_unique( + get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { impls.push_back( std::make_unique( - config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); + get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { @@ -1515,11 +1587,6 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u GGML_ASSERT(impl); - // TODO: currently only the implementation that generated the draft is used to accept it - // however, some implementations (such as MTP) need to also "see" the accepted tokens - // extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to - // inform the implementation if the accepted tokens are from another implementation and - // pass the accepted tokens to all remaining implementations using `is_other == true` { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { @@ -1527,9 +1594,16 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u impl->n_acc_tokens += n_accepted; } - impl->accept(seq_id, n_accepted); + impl->accept(seq_id, n_accepted, false); impl->n_call_accept++; } + + // accept with the rest of the implementations, using is_other == true + for (auto & impl_other : spec->impls) { + if (impl_other.get() != impl) { + impl_other->accept(seq_id, n_accepted, true); + } + } } void common_speculative_print_stats(const common_speculative * spec) { @@ -1549,7 +1623,7 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf = ""; } - LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", + LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s\n", common_speculative_type_to_str(impl->type).c_str(), impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, impl->n_gen_drafts, diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ff8400508..1d18a1bf9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -115,15 +115,15 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--mmproj", action="store_true", - help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", + help="Export multimodal projector (mmproj) for vision models. This will only work on some vision models. An 'mmproj-' prefix will be added to the output file name.", ) parser.add_argument( "--mtp", action="store_true", - help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.", + help="Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. An 'mtp-' prefix will be added to the output file name.", ) parser.add_argument( "--no-mtp", action="store_true", - help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.", + help="Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, but even though the bundled default is more space-efficient overall, this allows differing quantization which may be more performant.", ) parser.add_argument( "--mistral-format", action="store_true", diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 1b7334617..81658ba03 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -445,6 +445,11 @@ if __name__ == '__main__': if self.lazy: tensor = LazyTorchTensor.from_eager(tensor) base_name = get_base_tensor_name(name) + # filter base name, ignore tensor transformations for now + data_gen = lambda g=tensor: g # noqa: E731 + if (titem := self.filter_tensors((base_name, data_gen))) is None: + continue + base_name, _ = titem # note: mergekit-extract-lora also adds token embeddings to the adapter is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index e833070ee..4bdd239c0 100644 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -149,6 +149,8 @@ class TaskState: t_gen_ms: Optional[float] = None reasoning_content: Optional[str] = None server_name: Optional[str] = None + chunk_idx: int = 0 + problem_idx: int = 0 class EvalState: @@ -233,7 +235,9 @@ class EvalState: tps_gen: Optional[float] = None, t_gen_ms: Optional[float] = None, reasoning_content: Optional[str] = None, - server_name: Optional[str] = None + server_name: Optional[str] = None, + chunk_idx: int = 0, + problem_idx: int = 0, ): with self._lock: if "cases" not in self.task_states: @@ -252,7 +256,9 @@ class EvalState: "tps_gen": tps_gen, "t_gen_ms": t_gen_ms, "reasoning_content": reasoning_content, - "server_name": server_name + "server_name": server_name, + "chunk_idx": chunk_idx, + "problem_idx": problem_idx, } self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) @@ -289,6 +295,9 @@ class EvalState: all_cases = {} for i, task_id in tasks_to_save: question_text, prompt, expected = self.get_case(i) + # Extract chunk_idx from task_id for pending cases + _parts = task_id.rsplit("_", 2) + _chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0 if task_id in self.task_states.get("cases", {}): all_cases[task_id] = self.task_states["cases"][task_id] else: @@ -306,7 +315,9 @@ class EvalState: "tps_gen": None, "t_gen_ms": None, "reasoning_content": None, - "server_name": None + "server_name": None, + "chunk_idx": _chunk_idx, + "problem_idx": i, } ci_lower, ci_upper = self.accuracy_ci() @@ -382,11 +393,12 @@ class EvalState: grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) escaped_server = self._escape_html(server_name) + answer_class = status_class if status == "ok" else "" rows.append(f""" {task_id} {status_text} {self._escape_html(expected)} - {self._escape_html(answer)} + {self._escape_html(answer)} {tokens_str} {tps_str} {t_gen_str} @@ -405,6 +417,53 @@ class EvalState: rows_html = "\n".join(rows) + # ---- per-problem summary table ---- + problem_groups: Dict[int, List[Dict[str, Any]]] = {} + for _tid, _case in cases.items(): + if _case.get("status") != "ok": + continue + _pidx = _case.get("problem_idx") + if _pidx is None: + _p_parts = _tid.rsplit("_", 2) + _pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0 + problem_groups.setdefault(_pidx, []).append(_case) + + summary_rows_html = "" + if problem_groups: + def _stat(v, fmt=".1f", avg_fmt=None): + if not v: + return ("–", "–", "–") + af = fmt if avg_fmt is None else avg_fmt + return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}") + + summary_data = [] + for pidx, g in problem_groups.items(): + runs = len(g) + n_ok = sum(1 for c in g if c.get("correct", False)) + toks = [c["tokens"] for c in g if c.get("tokens") is not None] + tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None] + tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None] + summary_data.append(( + pidx, runs, n_ok, + _stat(toks, "d", ".0f"), + _stat(tps), + _stat(tg), + )) + + summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending + + summary_rows_html = "\n".join( + f""" + {p:03d} + {r} + {n}/{r} + {tk[0]}{tk[1]}{tk[2]} + {tp[0]}{tp[1]}{tp[2]} + {tg[0]}{tg[1]}{tg[2]} + """ + for p, r, n, tk, tp, tg in summary_data + ) + html_content = f""" @@ -412,10 +471,10 @@ class EvalState: {self.dataset_type.upper()} Eval
- {self.dataset_type.upper()} - Model: {self.model_name or 'N/A'} - Accuracy: {accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%] - Correct: {n_correct} / {len(completed)} - Pending: {n_pending} - Time: {self.total_time:.1f}s - Sampling: {sampling_str} +
Dataset
{self.dataset_type.upper()}
+
Model
{self.model_name or 'N/A'}
+
Accuracy
{accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]
+
Correct
{n_correct} / {len(completed)}
+
Pending
{n_pending}
+
Time
{self.total_time:.1f}s
+
Sampling
{sampling_str}
+
+
+ + +
+
+ + + + + + + + + + + + + + + {rows_html} + +
IDGoldAnswerTokensT/sGen sServer
+
+
+ + + + + + + + + + + + + + + + + + + + + {summary_rows_html} + +
ProblemRunsCorrectTokensT/sGen s
minavgmaxminavgmaxminavgmax
- - - - - - - - - - - - - - - {rows_html} - -
IDGoldAnswerTokensT/sGen sServer
""" @@ -1062,12 +1172,19 @@ class Processor: ) -> TaskState: question_text, prompt, expected = eval_state.get_case(i) + # Extract chunk_idx from task_id: "{dataset_type}_{chunk_idx:03d}_{index:03d}" + _parts = task_id.rsplit("_", 2) + chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0 + problem_idx = i + task_state = TaskState( task_id=task_id, prompt=prompt, expected=expected, question_text=question_text, - server_name=server_config.name + server_name=server_config.name, + chunk_idx=chunk_idx, + problem_idx=problem_idx, ) try: @@ -1085,7 +1202,8 @@ class Processor: eval_state.add_result( task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, - tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name, + chunk_idx, problem_idx, ) eval_state.dump() return task_state @@ -1108,7 +1226,8 @@ class Processor: eval_state.add_result( task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", - tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name, + chunk_idx, problem_idx, ) eval_state.dump() diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py index 2f9cdc545..e64ba8933 100644 --- a/examples/llama-eval/llama-server-simulator.py +++ b/examples/llama-eval/llama-server-simulator.py @@ -65,34 +65,70 @@ def normalize_number(s: str) -> Optional[int]: return int(match.group(0)) class AimeDataset: - def __init__(self, split: str = "train"): + def __init__(self, split: str = "train", dataset_type: str = "aime"): self.split = split + self.dataset_type = dataset_type self.questions: List[Dict] = [] self._load_dataset() - def _load_dataset(self): - print(f"Loading AIME dataset (split: {self.split})...") + def _get_question_text(self, question: Dict) -> str: + """Get question text, handling different dataset field names.""" + return question.get("problem", question.get("question", "")) - cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" - if cache_path.exists(): - print(f"Using cached dataset from {cache_path}") - ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + def _load_dataset(self): + if self.dataset_type == "aime": + print(f"Loading AIME dataset (split: {self.split})...") + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + elif self.dataset_type == "aime2025": + print(f"Loading AIME2025 dataset...") + ds_list = [] + for config_name in ["AIME2025-I", "AIME2025-II"]: + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test") + ds_list.extend(ds) + ds = ds_list else: - ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + raise ValueError(f"Unknown dataset type: {self.dataset_type}") self.questions = list(ds) - print(f"AIME dataset loaded: {len(self.questions)} questions") + print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions") def find_question(self, request_text: str) -> Optional[Dict]: + # Strip common template prefixes to get the actual question text + # Templates include things like "Solve the following math problem step by step..." + # The actual question usually follows a blank line or after the template instruction + cleaned = request_text + # Split on double newline and take the part that looks like the problem + parts = cleaned.split('\n\n') + if len(parts) > 1: + # Find the part that's longest (likely the actual problem text) + problem_parts = [p for p in parts if len(p.strip()) > 100] + if problem_parts: + cleaned = max(problem_parts, key=lambda x: len(x)) + best_match = None best_distance = -1 best_index = -1 for i, question in enumerate(self.questions): - question_text = question["problem"] - request_lower = request_text.lower() + question_text = self._get_question_text(question) + request_lower = cleaned.lower() question_lower = question_text.lower() + # Check if question text is contained in the cleaned request + if question_lower in request_lower or request_lower in question_lower: + debug_log(f"DEBUG: Found substring match at index {i}") + return question + # Exact match if question_lower == request_lower: debug_log(f"DEBUG: Found exact match at index {i}") @@ -118,7 +154,7 @@ class AimeDataset: debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") return best_match - debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...") + debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...") return None def get_answer(self, question: Dict) -> str: @@ -134,15 +170,16 @@ class Simulator: port: int = 8033, host: str = "localhost", success_rate: float = 0.8, - dataset_split: str = "train" + dataset_split: str = "train", + dataset_type: str = "aime" ): self.port = port self.host = host self.success_rate = success_rate - self.dataset = AimeDataset(dataset_split) + self.dataset = AimeDataset(dataset_split, dataset_type) self.eval_state = EvalState( - id="aime-2025", - tasks=["aime"], + id=dataset_type, + tasks=[dataset_type], task_states={}, sampling_config={"temperature": 0, "max_tokens": 2048} ) @@ -159,6 +196,10 @@ class Simulator: else: response_text = self._generate_wrong_answer(question) + comp_tokens = random.randint(10000, 60000) + tps_gen = random.uniform(90.0, 110.0) + t_gen_ms = comp_tokens / tps_gen * 1000 + return { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion", @@ -176,8 +217,12 @@ class Simulator: ], "usage": { "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150 + "completion_tokens": comp_tokens, + "total_tokens": 100 + comp_tokens + }, + "timings": { + "predicted_ms": t_gen_ms, + "predicted_per_second": tps_gen } } @@ -218,6 +263,12 @@ class Simulator: return response class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/v1/models": + self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200) + return + self._send_json({"error": "Not found"}, 404) + def do_POST(self): if self.path != "/v1/chat/completions": self._send_json({"error": "Not found"}, 404) @@ -280,6 +331,13 @@ def main(): default=0.8, help="Success rate 0-1 (default: 0.8)" ) + parser.add_argument( + "--dataset", + type=str, + default="aime", + choices=["aime", "aime2025"], + help="Dataset type (default: aime)" + ) parser.add_argument( "--dataset-split", type=str, @@ -294,7 +352,8 @@ def main(): port=args.port, host=args.host, success_rate=args.success_rate, - dataset_split=args.dataset_split + dataset_split=args.dataset_split, + dataset_type=args.dataset ) server = HTTPServer((args.host, args.port), RequestHandler) @@ -304,7 +363,7 @@ def main(): print("\n=== llama-server-simulator ===") print(f"Server running on http://{args.host}:{args.port}") print(f"Success rate: {args.success_rate}") - print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions") + print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions") print("\nPress Ctrl+C to stop\n") try: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index da48f313a..73a0991e2 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + return 8; case GGML_TYPE_Q6_K: + return 2; case GGML_TYPE_IQ4_NL: return 8; default: diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e288a27f9..ba006d9b3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2..8506000b6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const int nth = MIN(args.ne00, nth_max); - const int nk0 = (args.ne00 + nth - 1)/nth; ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); @@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f6ffb2b3a..4cf9dbea9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - if (K > 1u) { + if (K > 1) { const int target_slot = (int)t - shift; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; @@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl( } } - if (K == 1u) { + if (K == 1) { device float * dst_state = (device float *) (dst) + attn_size + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; @@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32( for (int64_t sx = x_min; sx < x_max; ++sx) { const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); const float w = wx * wy; - const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); sum += (*src_ptr) * w; wsum += w; } @@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32( const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; - const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); sum += (*src_ptr) * wx * wy; } } @@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32( } } -kernel void kernel_pad_f32( +template +kernel void kernel_pad_impl( constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; } - return; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1cb8f563d..d38057721 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -211,7 +219,6 @@ struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->last_graph_uid = cgraph->uid; + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { /* .endpoint = */ endpoint, /* .device = */ device, /* .name = */ dev_name, - /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { diff --git a/src/llama-graph.h b/src/llama-graph.h index 9e55d0a67..bf6778237 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -581,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index a59561ea5..72f5c2fea 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index fd305cab7..33b3b395e 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 549188990..dacb2933d 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -416,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 29c58afc9..b13b7b748 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -72,6 +72,7 @@ public: // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups uint32_t n_rs_seq = 0; + // per-seq rollback index std::vector rs_idx; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 2a4e00384..4f4c7cac7 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -447,13 +447,6 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } -bool llm_build_delta_net_base::keep_rs() const { - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - return cparams.n_rs_seq > 0 - && n_seq_tokens > 1 - && (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq; -} - ggml_tensor * llm_build_delta_net_base::build_conv_state( llm_graph_input_rs * inp, ggml_tensor * conv_states_all, @@ -461,12 +454,12 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state( int64_t conv_kernel_size, int64_t conv_channels, int il) { - const auto * mctx_cur = inp->mctx; - const auto kv_head = mctx_cur->get_head(); - const uint32_t mem_size = mctx_cur->get_size(); - const int64_t n_seqs = ubatch.n_seqs; - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const bool keep = keep_rs(); + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -480,32 +473,52 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state( ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); - if (!keep) { - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); } else { - const int64_t row_count = (conv_kernel_size - 1) * conv_channels; - const size_t row_size = row_count * ggml_element_size(conv_states_all); - for (int64_t t = 1; t <= n_seq_tokens; ++t) { - const uint32_t slot = (uint32_t)(n_seq_tokens - t); - ggml_tensor * src = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, - conv_input->nb[1], conv_input->nb[2], - t * ggml_element_size(conv_input)); - ggml_tensor * dst = - ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs, - conv_states_all->nb[1], - ((size_t) slot * mem_size + kv_head) * row_size); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); } } @@ -531,7 +544,9 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t n_seqs = s->ne[3]; const int64_t n_seq_tokens = q->ne[2]; - if (!keep_rs()) { + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { auto attn_out = build_delta_net(q, k, v, g, b, s, il); ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -547,14 +562,18 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( } const int64_t D = S_v * S_v * H_v; - const int64_t K = (int64_t) cparams.n_rs_seq + 1; + const int64_t K = cparams.n_rs_seq + 1; // TODO: remove pad + simplify - ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); - ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0); + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); - ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d); - cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; @@ -576,9 +595,11 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( ggml_row_size(gdn_out->type, S_v * S_v), ggml_row_size(gdn_out->type, S_v * S_v * H_v), ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], ((size_t) cache_slot * mem_size + kv_head) * row_size); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); } diff --git a/src/models/models.h b/src/models/models.h index 4e40536a5..7e551eb96 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -66,9 +66,6 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // true when speculative rollback is enabled and the batch fits in the rs cache - bool keep_rs() const; - // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) ggml_tensor * build_conv_state( diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 361d7538a..35a0158e8 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -496,7 +496,8 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - // The MTP block lives at the source file's original layer index. + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; const auto & layer = model.layers[il]; diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index bcdf44040..20a5ff1eb 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -30,7 +30,7 @@ int main(int argc, char ** argv) { if (!params.fit_params_print) { const common_params_fit_status status = common_fit_params(params.model.path.c_str(), &mparams, &cparams, params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, - params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); + params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); if (status != COMMON_PARAMS_FIT_STATUS_SUCCESS) { LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__); exit(1); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0f3fb9efa..dc3189e17 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -467,20 +467,26 @@ struct server_slot { const double n_gen_second = 1e3 / t_token_generation * n_decoded; SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second); + + SLT_INF(*this, + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + t_token_generation, n_decoded, t_gen, n_gen_second); + + SLT_INF(*this, " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + SLT_INF(*this, + " graphs reused = %10d\n", + llama_perf_context(ctx_tgt).n_reused); + if (n_draft_total > 0) { const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_CNT(*this, - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); + SLT_INF(*this, + "draft acceptance = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total); } common_speculative_print_stats(spec); @@ -2583,9 +2589,9 @@ private: llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, pos_next - n_swa); + const auto pos_min_thold = std::max(0, pos_next - n_swa - 1); - if (n_past > 0 && n_past < slot.prompt.n_tokens()) { + if (n_past > 0 && n_past <= slot.prompt.n_tokens()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); @@ -3885,6 +3891,7 @@ void server_routes::init_routes() { { "eos_token", meta->eos_token_str }, { "build_info", meta->build_info }, { "is_sleeping", queue_tasks.is_sleeping() }, + { "cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy }, }; if (params.use_jinja) { if (!tmpl_tools.empty()) { diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6c6fed52d..ccf42320f 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1165,6 +1165,7 @@ void server_models_routes::init_routes() { // Deprecated: use ui_settings instead (kept for backward compat) {"webui_settings", webui_settings}, {"build_info", std::string(llama_build_info())}, + {"cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy}, }); return res; } diff --git a/tools/ui/.env.example b/tools/ui/.env.example new file mode 100644 index 000000000..9a995b746 --- /dev/null +++ b/tools/ui/.env.example @@ -0,0 +1,2 @@ +VITE_PUBLIC_APP_NAME='llama-ui' +# VITE_DEBUG='true' diff --git a/tools/ui/.gitignore b/tools/ui/.gitignore index 051d884b0..22ed6125f 100644 --- a/tools/ui/.gitignore +++ b/tools/ui/.gitignore @@ -25,4 +25,4 @@ vite.config.ts.timestamp-* *storybook.log storybook-static -*.code-workspace \ No newline at end of file +*.code-workspace diff --git a/tools/ui/eslint.config.js b/tools/ui/eslint.config.js index 185da1dab..4ed9dd7ca 100644 --- a/tools/ui/eslint.config.js +++ b/tools/ui/eslint.config.js @@ -20,9 +20,7 @@ export default ts.config( prettier, ...svelte.configs.prettier, { - languageOptions: { - globals: { ...globals.browser, ...globals.node } - }, + languageOptions: { globals: { ...globals.browser, ...globals.node } }, rules: { // typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects. // see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors @@ -30,6 +28,7 @@ export default ts.config( 'svelte/no-at-html-tags': 'off', // This app uses hash-based routing (#/) where resolve() from $app/paths does not apply 'svelte/no-navigation-without-resolve': 'off', + // Enforce empty line at end of file 'eol-last': 'error' } diff --git a/tools/ui/package-lock.json b/tools/ui/package-lock.json index bf23307b8..4d012c819 100644 --- a/tools/ui/package-lock.json +++ b/tools/ui/package-lock.json @@ -2307,9 +2307,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.59.1", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.59.1.tgz", - "integrity": "sha512-d8OON70AphLdDesuTIl//M2O6fRTIicX8aYv8vhCiYEhTTI2OboKqey0Hu1A4VFhqwgqtq0vKDmPFGkw8kKmgw==", + "version": "2.60.1", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.60.1.tgz", + "integrity": "sha512-mQjlkNo+rJvpln7V2IGY2j99BqhcFbS4UN0AQNKNYfhBAFZTuCDAdW3a1sgf330mvtNvsBXn3HpAhcmvdJTcIQ==", "dev": true, "license": "MIT", "dependencies": { @@ -2318,7 +2318,7 @@ "@types/cookie": "^0.6.0", "acorn": "^8.14.1", "cookie": "^0.6.0", - "devalue": "^5.6.4", + "devalue": "^5.8.1", "esm-env": "^1.2.2", "kleur": "^4.1.5", "magic-string": "^0.30.5", @@ -4296,9 +4296,9 @@ } }, "node_modules/devalue": { - "version": "5.6.4", - "resolved": "https://registry.npmjs.org/devalue/-/devalue-5.6.4.tgz", - "integrity": "sha512-Gp6rDldRsFh/7XuouDbxMH3Mx8GMCcgzIb1pDTvNyn8pZGQ22u+Wa+lGV9dQCltFQ7uVw0MhRyb8XDskNFOReA==", + "version": "5.8.1", + "resolved": "https://registry.npmjs.org/devalue/-/devalue-5.8.1.tgz", + "integrity": "sha512-4CXDYRBGqN+57wVJkuXBYmpAVUSg3L6JAQa/DFqm238G73E1wuyc/JhGQJzN7vUf/CMphYau2zXbfWzDR5aTEw==", "license": "MIT" }, "node_modules/devlop": { @@ -4856,12 +4856,12 @@ } }, "node_modules/express-rate-limit": { - "version": "8.5.0", - "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.0.tgz", - "integrity": "sha512-XKhFohWaSBdVJNTi5TaHziqnPkv04I9UQV6q1Wy7Ui6GGQZVW12ojDFwqer14EvCXxjvPG0CyWXx7cAXpALB4Q==", + "version": "8.5.2", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.2.tgz", + "integrity": "sha512-5Kb34ipNX694DH48vN9irak1Qx30nb0PLYHXfJgw4YEjiC3ZEmZJhwOp+VfiCYwFzvFTdB9QkArYS5kXa2cx2A==", "license": "MIT", "dependencies": { - "ip-address": "10.1.0" + "ip-address": "^10.2.0" }, "engines": { "node": ">= 16" @@ -4909,9 +4909,9 @@ "license": "MIT" }, "node_modules/fast-uri": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", - "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.2.tgz", + "integrity": "sha512-rVjf7ArG3LTk+FS6Yw81V1DLuZl1bRbNrev6Tmd/9RaroeeRRJhAt7jg/6YFxbvAQXUCavSoZhPPj6oOx+5KjQ==", "funding": [ { "type": "github", @@ -5541,9 +5541,9 @@ } }, "node_modules/hono": { - "version": "4.12.14", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.14.tgz", - "integrity": "sha512-am5zfg3yu6sqn5yjKBNqhnTX7Cv+m00ox+7jbaKkrLMRJ4rAdldd1xPd/JzbBWspqaQv6RSTrgFN95EsfhC+7w==", + "version": "4.12.19", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.19.tgz", + "integrity": "sha512-xa3eYXYXx68XTT4hZ7dRzsXBhaq85ToSrlUJNoR0gwz/1Ap/CNwX47wfvV7pc/xWhjKVVkLT7zBJy8chhNguqQ==", "license": "MIT", "engines": { "node": ">=16.9.0" @@ -5722,9 +5722,9 @@ "license": "MIT" }, "node_modules/ip-address": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz", - "integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==", + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz", + "integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==", "license": "MIT", "engines": { "node": ">= 12" @@ -6008,9 +6008,9 @@ } }, "node_modules/katex": { - "version": "0.16.22", - "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.22.tgz", - "integrity": "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg==", + "version": "0.16.47", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.47.tgz", + "integrity": "sha512-Eeo8Ys1doU1z+x8AZsPpQu+p/QcZBI5PeOo7QGQdy2x2m0MU/hYagBbGOmXwr5KVbEfVuWv9LpnQWeehogurjg==", "dev": true, "funding": [ "https://opencollective.com/katex", @@ -9245,9 +9245,9 @@ } }, "node_modules/svelte": { - "version": "5.55.1", - "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.1.tgz", - "integrity": "sha512-QjvU7EFemf6mRzdMGlAFttMWtAAVXrax61SZYHdkD6yoVGQ89VeyKfZD4H1JrV1WLmJBxWhFch9H6ig/87VGjw==", + "version": "5.55.7", + "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.7.tgz", + "integrity": "sha512-ymI5ykLPwIHW839E053FQbI1G+jnRFJEw3Kv5Y4njixVWywQBx+NUFpkkKyk5LIb36Fg9DVXSYpqiGekLD0hyw==", "license": "MIT", "dependencies": { "@jridgewell/remapping": "^2.3.4", @@ -9259,7 +9259,7 @@ "aria-query": "5.3.1", "axobject-query": "^4.1.0", "clsx": "^2.1.1", - "devalue": "^5.6.4", + "devalue": "^5.8.1", "esm-env": "^1.2.1", "esrap": "^2.2.4", "is-reference": "^3.0.3", @@ -10606,9 +10606,9 @@ "license": "ISC" }, "node_modules/ws": { - "version": "8.18.3", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", - "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", + "version": "8.20.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.20.1.tgz", + "integrity": "sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w==", "dev": true, "license": "MIT", "engines": { diff --git a/tools/ui/src/app.css b/tools/ui/src/app.css index d6dc6670c..29b1d3c64 100644 --- a/tools/ui/src/app.css +++ b/tools/ui/src/app.css @@ -1,6 +1,7 @@ @import 'tailwindcss'; -@source "."; - +@source '.'; +@plugin '@tailwindcss/forms'; +@plugin '@tailwindcss/typography'; @import 'tw-animate-css'; @custom-variant dark (&:is(.dark *)); diff --git a/tools/ui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsPreview/ChatAttachmentsPreviewCurrentItem/ChatAttachmentsPreviewCurrentItemVideo.svelte b/tools/ui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsPreview/ChatAttachmentsPreviewCurrentItem/ChatAttachmentsPreviewCurrentItemVideo.svelte index 4ebbd5922..62040b36f 100644 --- a/tools/ui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsPreview/ChatAttachmentsPreviewCurrentItem/ChatAttachmentsPreviewCurrentItemVideo.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsPreview/ChatAttachmentsPreviewCurrentItem/ChatAttachmentsPreviewCurrentItemVideo.svelte @@ -15,6 +15,7 @@ {#if videoSrc} {:else} diff --git a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionModels.svelte b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionModels.svelte index 2f9471e0d..297020605 100644 --- a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionModels.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionModels.svelte @@ -7,7 +7,6 @@ import { activeMessages } from '$lib/stores/conversations.svelte'; interface Props { - currentModel?: string; disabled?: boolean; forceForegroundText?: boolean; hasAudioModality?: boolean; @@ -20,7 +19,6 @@ } let { - currentModel, disabled = false, forceForegroundText = false, hasAudioModality = $bindable(false), @@ -41,14 +39,28 @@ let lastSyncedConversationModel: string | null = null; + let selectorModel = $derived(conversationModel ?? modelsStore.selectedModelName ?? null); + $effect(() => { if (conversationModel && conversationModel !== lastSyncedConversationModel) { - lastSyncedConversationModel = conversationModel; + if (modelOptions().some((m) => m.model === conversationModel)) { + modelsStore.selectedModelName = conversationModel; + modelsStore.selectModelByName(conversationModel); + } else { + modelsStore.selectedModelName = null; + modelsStore.clearSelection(); + } - modelsStore.selectModelByName(conversationModel); - } else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) { + lastSyncedConversationModel = conversationModel; + } else if ( + isRouter && + !modelsStore.selectedModelId && + modelsStore.loadedModelIds.length > 0 && + activeMessages().length > 0 && + !conversationModel + ) { lastSyncedConversationModel = null; - // auto-select the first loaded model only when nothing is selected yet + const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model)); if (first) modelsStore.selectModelById(first.id); @@ -151,7 +163,7 @@ @@ -159,7 +171,7 @@ diff --git a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormPickers/ChatFormPickerMcpPrompts/ChatFormPickerMcpPrompts.svelte b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormPickers/ChatFormPickerMcpPrompts/ChatFormPickerMcpPrompts.svelte index 567fdac47..ff734ac88 100644 --- a/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormPickers/ChatFormPickerMcpPrompts/ChatFormPickerMcpPrompts.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatForm/ChatFormPickers/ChatFormPickerMcpPrompts/ChatFormPickerMcpPrompts.svelte @@ -162,7 +162,7 @@ return; } - if (import.meta.env.DEV) { + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log('[ChatFormPickerMcpPrompts] Fetching completions for:', { serverName: selectedPrompt.serverName, promptName: selectedPrompt.name, @@ -181,7 +181,7 @@ value ); - if (import.meta.env.DEV) { + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log('[ChatFormPickerMcpPrompts] Autocomplete result:', { argName, value, diff --git a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index dc3eab134..e733a64a9 100644 --- a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -1,20 +1,20 @@ + +
+ +
diff --git a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenGreeting.svelte b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenGreeting.svelte new file mode 100644 index 000000000..141d4f4e4 --- /dev/null +++ b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenGreeting.svelte @@ -0,0 +1,25 @@ + + + diff --git a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenProcessingInfo.svelte b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenProcessingInfo.svelte index b5979db13..f38f3519c 100644 --- a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenProcessingInfo.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenProcessingInfo.svelte @@ -6,6 +6,7 @@ import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte'; import { config } from '$lib/stores/settings.svelte'; import { getProcessingInfoContext } from '$lib/contexts'; + import { page } from '$app/state'; const processingState = useProcessingState(); const processingInfoCtx = getProcessingInfoContext(); @@ -16,6 +17,14 @@ let isStreaming = $derived(isChatStreaming()); let processingDetails = $derived(processingState.getTechnicalDetails()); + let processingVisible = $derived(processingDetails.length > 0); + + let { onVisibilityChange }: { onVisibilityChange?: (visible: boolean) => void } = $props(); + + $effect(() => { + onVisibilityChange?.(processingVisible); + }); + $effect(() => { const conversation = activeConversation(); @@ -60,9 +69,12 @@
-
+
{#each processingDetails as detail (detail)} {detail} {/each} diff --git a/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenServerError.svelte b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenServerError.svelte new file mode 100644 index 000000000..2a998dbeb --- /dev/null +++ b/tools/ui/src/lib/components/app/chat/ChatScreen/ChatScreenServerError.svelte @@ -0,0 +1,34 @@ + + +{#if hasError} +
+ + + + + Server unavailable + + + + + {serverError()} + +
+{/if} diff --git a/tools/ui/src/lib/components/app/chat/index.ts b/tools/ui/src/lib/components/app/chat/index.ts index 5f6597980..be5535960 100644 --- a/tools/ui/src/lib/components/app/chat/index.ts +++ b/tools/ui/src/lib/components/app/chat/index.ts @@ -667,3 +667,17 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte'; * Only visible when `isCurrentConversationLoading` is true. */ export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte'; + +/** + * Scroll-to-bottom action button. Displays a floating button when the user + * has scrolled up more than half a viewport height from the bottom. + * Takes the chat container element as a prop to manage scroll state internally. + */ +export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte'; + +/** + * Server error alert displayed when the server is unreachable. + * Shows the error message with a retry button. + * Rendered inside ChatScreen when `serverError` store has a value. + */ +export { default as ChatScreenServerError } from './ChatScreen/ChatScreenServerError.svelte'; diff --git a/tools/ui/src/lib/components/app/content/MarkdownContent/MarkdownContent.svelte b/tools/ui/src/lib/components/app/content/MarkdownContent/MarkdownContent.svelte index 3a11854b6..0412414ae 100644 --- a/tools/ui/src/lib/components/app/content/MarkdownContent/MarkdownContent.svelte +++ b/tools/ui/src/lib/components/app/content/MarkdownContent/MarkdownContent.svelte @@ -28,7 +28,7 @@ SETTINGS_KEYS } from '$lib/constants'; import { ColorMode, UrlProtocol } from '$lib/enums'; - import { FileTypeText } from '$lib/enums/files'; + import { FileTypeText } from '$lib/enums/files.enums'; import { highlightCode, detectIncompleteCodeBlock, type IncompleteCodeBlock } from '$lib/utils'; import '$styles/katex-custom.scss'; import githubDarkCss from 'highlight.js/styles/github-dark.css?inline'; diff --git a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte index 109c8ff9d..d017fe204 100644 --- a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte +++ b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte @@ -17,7 +17,7 @@ } from '$lib/constants'; import { RouterService } from '$lib/services/router.service'; import { setMode } from 'mode-watcher'; - import { ColorMode } from '$lib/enums/ui'; + import { ColorMode } from '$lib/enums/ui.enums'; import { fade } from 'svelte/transition'; import { goto } from '$app/navigation'; import { page } from '$app/state'; diff --git a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte index 069855eeb..7c1c5c897 100644 --- a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte +++ b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChatFields.svelte @@ -6,7 +6,7 @@ import * as Select from '$lib/components/ui/select'; import { Textarea } from '$lib/components/ui/textarea'; import { SETTING_CONFIG_INFO, SETTINGS_KEYS } from '$lib/constants'; - import { SettingsFieldType } from '$lib/enums/settings'; + import { SettingsFieldType } from '$lib/enums/settings.enums'; import { settingsStore } from '$lib/stores/settings.svelte'; import { serverStore } from '$lib/stores/server.svelte'; import { modelsStore, selectedModelName, propsCacheVersion } from '$lib/stores/models.svelte'; diff --git a/tools/ui/src/lib/constants/mcp.ts b/tools/ui/src/lib/constants/mcp.ts index 19bdd92ea..918eb9f94 100644 --- a/tools/ui/src/lib/constants/mcp.ts +++ b/tools/ui/src/lib/constants/mcp.ts @@ -2,7 +2,7 @@ import { Zap, Globe, Radio } from '@lucide/svelte'; import { MCPTransportType } from '$lib/enums'; import type { ClientCapabilities, Implementation } from '$lib/types'; import type { Component } from 'svelte'; -import { MimeTypeImage } from '$lib/enums/files'; +import { MimeTypeImage } from '$lib/enums/files.enums'; export const DEFAULT_CLIENT_VERSION = '1.0.0'; export const MCP_CLIENT_NAME = 'llama-ui-mcp'; diff --git a/tools/ui/src/lib/constants/settings-registry.ts b/tools/ui/src/lib/constants/settings-registry.ts index bdbb17d96..93b3cd5ed 100644 --- a/tools/ui/src/lib/constants/settings-registry.ts +++ b/tools/ui/src/lib/constants/settings-registry.ts @@ -1,5 +1,5 @@ -import { ColorMode } from '$lib/enums/ui'; -import { SettingsFieldType } from '$lib/enums/settings'; +import { ColorMode } from '$lib/enums/ui.enums'; +import { SettingsFieldType } from '$lib/enums/settings.enums'; import { SyncableParameterType } from '$lib/enums'; import { Funnel, diff --git a/tools/ui/src/lib/constants/supported-file-types.ts b/tools/ui/src/lib/constants/supported-file-types.ts index 345054389..414116154 100644 --- a/tools/ui/src/lib/constants/supported-file-types.ts +++ b/tools/ui/src/lib/constants/supported-file-types.ts @@ -18,7 +18,7 @@ import { MimeTypeApplication, MimeTypeText } from '$lib/enums'; -import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files'; +import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files.enums'; // File type configuration using enums export const AUDIO_FILE_TYPES = { diff --git a/tools/ui/src/lib/constants/tools.ts b/tools/ui/src/lib/constants/tools.ts index 22b22309c..efc3476cd 100644 --- a/tools/ui/src/lib/constants/tools.ts +++ b/tools/ui/src/lib/constants/tools.ts @@ -1,4 +1,4 @@ -import { ToolSource } from '$lib/enums/tools'; +import { ToolSource } from '$lib/enums/tools.enums'; export const TOOL_GROUP_LABELS = { [ToolSource.BUILTIN]: 'Built-in', diff --git a/tools/ui/src/lib/enums/agentic.ts b/tools/ui/src/lib/enums/agentic.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/agentic.ts rename to tools/ui/src/lib/enums/agentic.enums.ts diff --git a/tools/ui/src/lib/enums/attachment.ts b/tools/ui/src/lib/enums/attachment.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/attachment.ts rename to tools/ui/src/lib/enums/attachment.enums.ts diff --git a/tools/ui/src/lib/enums/chat.ts b/tools/ui/src/lib/enums/chat.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/chat.ts rename to tools/ui/src/lib/enums/chat.enums.ts diff --git a/tools/ui/src/lib/enums/files.ts b/tools/ui/src/lib/enums/files.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/files.ts rename to tools/ui/src/lib/enums/files.enums.ts diff --git a/tools/ui/src/lib/enums/index.ts b/tools/ui/src/lib/enums/index.ts index 3cf81286b..a17cca1d8 100644 --- a/tools/ui/src/lib/enums/index.ts +++ b/tools/ui/src/lib/enums/index.ts @@ -4,9 +4,9 @@ export { AttachmentItemEnabledWhen, AttachmentAction, AttachmentItemVisibleWhen -} from './attachment'; +} from './attachment.enums'; -export { AgenticSectionType, ToolCallType } from './agentic'; +export { AgenticSectionType, ToolCallType } from './agentic.enums'; export { ChatMessageStatsView, @@ -17,7 +17,7 @@ export { MessageType, PdfViewMode, ReasoningFormat -} from './chat'; +} from './chat.enums'; export { FileTypeCategory, @@ -38,7 +38,7 @@ export { MimeTypeImage, MimeTypeText, SpecialFileType -} from './files'; +} from './files.enums'; export { MCPConnectionPhase, @@ -48,16 +48,16 @@ export { MCPContentType, MCPRefType, JsonSchemaType -} from './mcp'; +} from './mcp.enums'; -export { ModelModality } from './model'; +export { ModelModality } from './model.enums'; -export { ServerRole, ServerModelStatus } from './server'; +export { ServerRole, ServerModelStatus } from './server.enums'; -export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings'; +export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums'; -export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui'; +export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui.enums'; -export { KeyboardKey } from './keyboard'; +export { KeyboardKey } from './keyboard.enums'; -export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools'; +export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools.enums'; diff --git a/tools/ui/src/lib/enums/keyboard.ts b/tools/ui/src/lib/enums/keyboard.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/keyboard.ts rename to tools/ui/src/lib/enums/keyboard.enums.ts diff --git a/tools/ui/src/lib/enums/mcp.ts b/tools/ui/src/lib/enums/mcp.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/mcp.ts rename to tools/ui/src/lib/enums/mcp.enums.ts diff --git a/tools/ui/src/lib/enums/model.ts b/tools/ui/src/lib/enums/model.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/model.ts rename to tools/ui/src/lib/enums/model.enums.ts diff --git a/tools/ui/src/lib/enums/server.ts b/tools/ui/src/lib/enums/server.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/server.ts rename to tools/ui/src/lib/enums/server.enums.ts diff --git a/tools/ui/src/lib/enums/settings.ts b/tools/ui/src/lib/enums/settings.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/settings.ts rename to tools/ui/src/lib/enums/settings.enums.ts diff --git a/tools/ui/src/lib/enums/tools.ts b/tools/ui/src/lib/enums/tools.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/tools.ts rename to tools/ui/src/lib/enums/tools.enums.ts diff --git a/tools/ui/src/lib/enums/ui.ts b/tools/ui/src/lib/enums/ui.enums.ts similarity index 100% rename from tools/ui/src/lib/enums/ui.ts rename to tools/ui/src/lib/enums/ui.enums.ts diff --git a/tools/ui/src/lib/hooks/use-auto-scroll.svelte.ts b/tools/ui/src/lib/hooks/use-auto-scroll.svelte.ts index f59e3ed4b..7bac452e4 100644 --- a/tools/ui/src/lib/hooks/use-auto-scroll.svelte.ts +++ b/tools/ui/src/lib/hooks/use-auto-scroll.svelte.ts @@ -100,6 +100,14 @@ export class AutoScrollController { this._autoScrollEnabled = true; } + /** + * Resets scroll state when switching conversations. + */ + resetScrollState(): void { + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + /** * Starts the auto-scroll interval for continuous scrolling during streaming. */ diff --git a/tools/ui/src/lib/hooks/use-models-selector.svelte.ts b/tools/ui/src/lib/hooks/use-models-selector.svelte.ts index 537a2af18..098cb2c27 100644 --- a/tools/ui/src/lib/hooks/use-models-selector.svelte.ts +++ b/tools/ui/src/lib/hooks/use-models-selector.svelte.ts @@ -66,7 +66,6 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele const serverModel = $derived(singleModelName()); const currentModel = $derived(opts.currentModel()); - const useGlobalSelection = $derived(opts.useGlobalSelection?.() ?? false); const onModelChange = $derived(opts.onModelChange?.()); const isHighlightedCurrentModelActive = $derived.by(() => { @@ -128,6 +127,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele if (onModelChange) { const result = await onModelChange(option.id, option.model); + if (result === false) { shouldCloseMenu = false; } @@ -142,12 +142,14 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele const textarea = document.querySelector( '[data-slot="chat-form"] textarea' ); + textarea?.focus(); }); } if (!onModelChange && isRouter && !modelsStore.isModelLoaded(option.model)) { isLoadingModel = true; + modelsStore .loadModel(option.model) .catch((error) => console.error('Failed to load model:', error)) @@ -158,6 +160,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele function getDisplayOption(): ModelOption | undefined { if (!isRouter) { const displayModel = serverModel || currentModel; + if (displayModel) { return { id: serverModel ? 'current' : 'offline-current', @@ -166,12 +169,8 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele capabilities: [] }; } - return undefined; - } - if (useGlobalSelection && activeId) { - const selected = options.find((option) => option.id === activeId); - if (selected) return selected; + return undefined; } if (currentModel) { @@ -183,6 +182,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele capabilities: [] }; } + return options.find((option) => option.model === currentModel); } @@ -197,57 +197,77 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele get options() { return options; }, + get loading() { return loading; }, + get updating() { return updating; }, + get activeId() { return activeId; }, + get isRouter() { return isRouter; }, + get serverModel() { return serverModel; }, + get isHighlightedCurrentModelActive() { return isHighlightedCurrentModelActive; }, + get isCurrentModelInCache() { return isCurrentModelInCache; }, + get filteredOptions() { return filteredOptions; }, + get groupedFilteredOptions() { return groupedFilteredOptions; }, + get isLoadingModel() { return isLoadingModel; }, + get searchTerm() { return searchTerm; }, + get showModelDialog() { return showModelDialog; }, + get infoModelId() { return infoModelId; }, + setSearchTerm(value: string) { searchTerm = value; }, + setShowModelDialog(value: boolean) { showModelDialog = value; }, + handleInfoClick, + handleSelect, + handleOpenChange, + isFavorite(model: string) { return modelsStore.favoriteModelIds.has(model); }, + getDisplayOption }; } diff --git a/tools/ui/src/lib/services/mcp.service.ts b/tools/ui/src/lib/services/mcp.service.ts index 44cbd4a8a..d596381aa 100644 --- a/tools/ui/src/lib/services/mcp.service.ts +++ b/tools/ui/src/lib/services/mcp.service.ts @@ -392,7 +392,7 @@ export class MCPService { const url = new URL(config.url); - if (import.meta.env.DEV) { + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log(`[MCPService] Creating WebSocket transport for ${url.href}`); } @@ -413,12 +413,12 @@ export class MCPService { onLog ); - if (useProxy && import.meta.env.DEV) { + if (useProxy && import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`); } try { - if (import.meta.env.DEV) { + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log(`[MCPService] Creating StreamableHTTP transport for ${url.href}`); } @@ -520,7 +520,7 @@ export class MCPService { ) ); - if (import.meta.env.DEV) { + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { console.log(`[MCPService][${serverName}] Creating transport...`); } @@ -560,6 +560,22 @@ export class MCPService { ); const runtimeErrorHandler = (error: Error) => { + // Ignore errors that are expected when the SDK's transport is closed, + // or when connecting to servers that don't support SSE (stateless-only + // endpoints returning 405). The SDK wraps the original AbortError in + // a new Error with the message "SSE stream disconnected: AbortError", + // and also produces "Cannot cancel a stream locked by a reader". + // DOMException is thrown by the browser when aborting fetch requests. + const msg = error.message || String(error); + if ( + error.name === 'AbortError' || + error instanceof DOMException || + msg.includes('SSE stream disconnected') || + msg.includes('stream locked by a reader') || + msg.includes('The operation was aborted') + ) { + return; + } console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error); }; @@ -658,7 +674,10 @@ export class MCPService { this.createLog(MCPConnectionPhase.LISTING_TOOLS, 'Listing available tools...') ); - console.log(`[MCPService][${serverName}] Connected, listing tools...`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + console.log(`[MCPService][${serverName}] Connected, listing tools...`); + } + const tools = await this.listTools({ client, transport, @@ -680,10 +699,11 @@ export class MCPService { `Connection established with ${tools.length} tools (${connectionTimeMs}ms)` ) ); - - console.log( - `[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms` - ); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + console.log( + `[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms` + ); + } return { client, @@ -709,9 +729,22 @@ export class MCPService { * @param connection - The active MCP connection to close */ static async disconnect(connection: MCPConnection): Promise { - console.log(`[MCPService][${connection.serverName}] Disconnecting...`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + console.log(`[MCPService][${connection.serverName}] Disconnecting...`); + } + try { - // Prevent reconnection on voluntary disconnect + // Terminate the session first for streamable-http transports to cleanly + // close streams, matching the inspector's disconnect flow. + if (connection.transport instanceof StreamableHTTPClientTransport) { + await connection.transport.terminateSession(); + } + + // Clear error handlers before closing to prevent noise from expected + // abort errors during shutdown. The inspector avoids this entirely + // by not setting onerror, but since we use it for protocol logging, + // we must clear it before disconnect. + connection.client.onerror = undefined; if (connection.transport.onclose) { connection.transport.onclose = undefined; } @@ -1078,7 +1111,9 @@ export class MCPService { try { await connection.client.unsubscribeResource({ uri }); - console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`); + } } catch (error) { console.error( `[MCPService][${connection.serverName}] Failed to unsubscribe from resource:`, diff --git a/tools/ui/src/lib/services/migration.service.ts b/tools/ui/src/lib/services/migration.service.ts index 5ed24c00d..35d47070a 100644 --- a/tools/ui/src/lib/services/migration.service.ts +++ b/tools/ui/src/lib/services/migration.service.ts @@ -119,7 +119,8 @@ const localStorageMigration: Migration = { // Only migrate if new key doesn't already exist const newValue = localStorage.getItem(newKey); if (newValue !== null) { - console.log(`[Migration] localStorage: ${newKey} already exists, skipping`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] localStorage: ${newKey} already exists, skipping`); continue; } @@ -127,9 +128,11 @@ const localStorageMigration: Migration = { if (oldValue !== null) { localStorage.setItem(newKey, oldValue); // Keep old key for downgrade compatibility - DO NOT DELETE - console.log( - `[Migration] localStorage: copied ${deprecatedKey} → ${newKey} (preserved old)` - ); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + console.log( + `[Migration] localStorage: copied ${deprecatedKey} → ${newKey} (preserved old)` + ); + } } } } @@ -146,7 +149,8 @@ const idxdbMigration: Migration = { async run(): Promise { const oldDbNames = await Dexie.getDatabaseNames(); if (!oldDbNames.includes(DB_APP_NAME_DEPRECATED)) { - console.log('[Migration] IndexedDB: no old database found, skipping'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] IndexedDB: no old database found, skipping'); return; } @@ -155,11 +159,13 @@ const idxdbMigration: Migration = { newDb.version(1).stores(IDXDB_STORES); const existingConvs = await newDb.table(IDXDB_TABLES.conversations).count(); if (existingConvs > 0) { - console.log('[Migration] IndexedDB: new database already has data, skipping'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] IndexedDB: new database already has data, skipping'); return; } - console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED); const oldDb = new Dexie(DB_APP_NAME_DEPRECATED); oldDb.version(1).stores(IDXDB_STORES); @@ -169,15 +175,18 @@ const idxdbMigration: Migration = { if (conversations.length > 0) { await newDb.table(IDXDB_TABLES.conversations).bulkAdd(conversations); - console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`); } if (messages.length > 0) { await newDb.table(IDXDB_TABLES.messages).bulkAdd(messages); - console.log(`[Migration] IndexedDB: copied ${messages.length} messages`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] IndexedDB: copied ${messages.length} messages`); } // Non-destructive: DO NOT delete old database - keep for downgrade compatibility - console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility'); } }; @@ -419,7 +428,8 @@ const legacyMessageMigration: Migration = { } } - console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`); } }; @@ -434,7 +444,8 @@ const themeMigration: Migration = { async run(): Promise { const legacyTheme = localStorage.getItem('theme'); if (legacyTheme === null) { - console.log('[Migration] Theme: no legacy theme key found, skipping'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] Theme: no legacy theme key found, skipping'); return; } @@ -443,7 +454,8 @@ const themeMigration: Migration = { const config = configRaw ? JSON.parse(configRaw) : {}; if (SETTINGS_KEYS.THEME in config) { - console.log('[Migration] Theme: config already has theme, skipping'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] Theme: config already has theme, skipping'); return; } @@ -451,7 +463,8 @@ const themeMigration: Migration = { localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config)); // Non-destructive: DO NOT delete legacy theme key - keep for downgrade compatibility - console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`); } }; @@ -491,7 +504,8 @@ export const MigrationService = { */ resetState(): void { localStorage.removeItem(MIGRATION_STATE_KEY); - console.log('[Migration] State reset - all migrations will run again'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] State reset - all migrations will run again'); }, /** @@ -500,25 +514,30 @@ export const MigrationService = { */ async runAllMigrations(): Promise { const state = getMigrationState(); - console.log('[Migration] Starting migration run, state:', state); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] Starting migration run, state:', state); for (const migration of migrations) { if (isMigrationCompleted(migration.id)) { - console.log(`[Migration] ${migration.id}: already completed, skipping`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] ${migration.id}: already completed, skipping`); continue; } try { - console.log(`[Migration] ${migration.id}: running...`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] ${migration.id}: running...`); await migration.run(); markMigrationCompleted(migration.id); - console.log(`[Migration] ${migration.id}: completed successfully`); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log(`[Migration] ${migration.id}: completed successfully`); } catch (error) { console.error(`[Migration] ${migration.id}: failed`, error); markMigrationFailed(migration.id); } } - console.log('[Migration] All migrations complete'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] All migrations complete'); } }; diff --git a/tools/ui/src/lib/stores/mcp.svelte.ts b/tools/ui/src/lib/stores/mcp.svelte.ts index 8fb306da8..effb78e33 100644 --- a/tools/ui/src/lib/stores/mcp.svelte.ts +++ b/tools/ui/src/lib/stores/mcp.svelte.ts @@ -20,11 +20,11 @@ */ import { browser } from '$app/environment'; -import { base } from '$app/paths'; import { SETTINGS_KEYS } from '$lib/constants'; import { MCPService } from '$lib/services/mcp.service'; import { config, settingsStore } from '$lib/stores/settings.svelte'; import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte'; +import { serverStore } from '$lib/stores/server.svelte'; import { mode } from 'mode-watcher'; import { parseMcpServerSettings, @@ -43,7 +43,6 @@ import { ToolCallType } from '$lib/enums'; import { - CORS_PROXY_ENDPOINT, DEFAULT_CACHE_TTL_MS, DEFAULT_MCP_CONFIG, EXPECTED_THEMED_ICON_PAIR_COUNT, @@ -86,7 +85,6 @@ class MCPStore { private _toolCount = $state(0); private _connectedServers = $state([]); private _healthChecks = $state>({}); - private _proxyAvailable = $state(false); private connections = new Map(); private toolsIndex = new Map(); @@ -96,27 +94,8 @@ class MCPStore { private initPromise: Promise | null = null; private activeFlowCount = 0; - constructor() { - if (browser) { - this.probeProxy(); - } - } - - /** - * Probes the CORS proxy endpoint to determine availability. - * The endpoint is only registered when llama-server runs with --ui-mcp-proxy. - */ - async probeProxy(): Promise { - try { - const response = await fetch(`${base}${CORS_PROXY_ENDPOINT}`, { method: 'HEAD' }); - this._proxyAvailable = response.status !== 404; - } catch { - this._proxyAvailable = false; - } - } - get isProxyAvailable(): boolean { - return this._proxyAvailable; + return serverStore.props?.cors_proxy_enabled ?? false; } /** diff --git a/tools/ui/src/lib/stores/models.svelte.ts b/tools/ui/src/lib/stores/models.svelte.ts index 45981b38f..bc99d7412 100644 --- a/tools/ui/src/lib/stores/models.svelte.ts +++ b/tools/ui/src/lib/stores/models.svelte.ts @@ -3,7 +3,7 @@ import { toast } from 'svelte-sonner'; import { ServerModelStatus, ModelModality } from '$lib/enums'; import { ModelsService } from '$lib/services/models.service'; import { PropsService } from '$lib/services/props.service'; -import { serverStore } from '$lib/stores/server.svelte'; +import { serverStore, isRouterMode } from '$lib/stores/server.svelte'; import { TTLCache } from '$lib/utils'; import { MODEL_PROPS_CACHE_TTL_MS, @@ -14,14 +14,7 @@ import { import { conversationsStore } from '$lib/stores/conversations.svelte'; /** - * modelsStore - Reactive store for model management in both MODEL and ROUTER modes - * - * This store manages: - * - Available models list - * - Selected model for new conversations - * - Loaded models tracking (ROUTER mode) - * - Model usage tracking per conversation - * - Automatic unloading of unused models + * modelsStore - Reactive store for model management in both MODEL and ROUTER modes. * * **Architecture & Relationships:** * - **ModelsService**: Stateless service for model API communication @@ -31,14 +24,8 @@ import { conversationsStore } from '$lib/stores/conversations.svelte'; * * **API Inconsistency Workaround:** * In MODEL mode, `/props` returns modalities for the single model. - * In ROUTER mode, `/props` has no modalities - must use `/props?model=` per model. + * In ROUTER mode, `/props` has no modalities — must use `/props?model=` per model. * This store normalizes this behavior so consumers don't need to know the server mode. - * - * **Key Features:** - * - **MODEL mode**: Single model, always loaded - * - **ROUTER mode**: Multi-model with load/unload capability - * - **Auto-unload**: Automatically unloads models not used by any conversation - * - **Lazy loading**: ensureModelLoaded() loads models on demand */ class ModelsStore { /** @@ -57,8 +44,8 @@ class ModelsStore { selectedModelId = $state(null); selectedModelName = $state(null); - // dedup concurrent fetch() callers, all awaiters share the same inflight promise - // without this, ?model= URL handler raced an in-progress fetch and saw an empty list + // Dedup concurrent fetch() callers — all awaiters share the same inflight promise. + // Without this, ?model= URL handler races an in-progress fetch and sees an empty list. private inflightFetch: Promise | null = null; private modelUsage = $state>>(new Map()); @@ -67,9 +54,9 @@ class ModelsStore { favoriteModelIds = $state>(this.loadFavoritesFromStorage()); /** - * Model-specific props cache with TTL - * Key: modelId, Value: props data including modalities - * TTL: 10 minutes - props don't change frequently + * Model-specific props cache with TTL. + * Key: modelId, Value: props data including modalities. + * TTL: 10 minutes — props don't change frequently. */ private modelPropsCache = new TTLCache({ ttlMs: MODEL_PROPS_CACHE_TTL_MS, @@ -78,7 +65,7 @@ class ModelsStore { private modelPropsFetching = $state>(new Set()); /** - * Version counter for props cache - used to trigger reactivity when props are updated + * Version counter for props cache — used to trigger reactivity when props are updated. */ propsCacheVersion = $state(0); @@ -92,7 +79,7 @@ class ModelsStore { get selectedModel(): ModelOption | null { if (!this.selectedModelId) return null; - return this.models.find((model) => model.id === this.selectedModelId) ?? null; + return this.models.find((m) => m.id === this.selectedModelId) ?? null; } get loadedModelIds(): string[] { @@ -117,7 +104,7 @@ class ModelsStore { * In ROUTER mode, returns null (model is per-conversation). */ get singleModelName(): string | null { - if (serverStore.isRouterMode) return null; + if (isRouterMode()) return null; const props = serverStore.props; if (props?.model_alias) return props.model_alias; @@ -126,6 +113,11 @@ class ModelsStore { return props.model_path.split(/(\\|\/)/).pop() || null; } + get selectedModelContextSize(): number | null { + if (!this.selectedModelName) return null; + return this.getModelContextSize(this.selectedModelName); + } + /** * * @@ -134,10 +126,6 @@ class ModelsStore { * */ - /** - * Get modalities for a specific model - * Returns cached modalities from model props - */ getModelModalities(modelId: string): ModelModalities | null { const model = this.models.find((m) => m.model === modelId || m.id === modelId); if (model?.modalities) { @@ -146,46 +134,29 @@ class ModelsStore { const props = this.modelPropsCache.get(modelId); if (props?.modalities) { - return { - vision: props.modalities.vision ?? false, - audio: props.modalities.audio ?? false, - video: props.modalities.video ?? false - }; + return this.buildModalities(props.modalities); } return null; } - /** - * Check if a model supports vision modality - */ modelSupportsVision(modelId: string): boolean { return this.getModelModalities(modelId)?.vision ?? false; } - /** - * Check if a model supports audio modality - */ modelSupportsAudio(modelId: string): boolean { return this.getModelModalities(modelId)?.audio ?? false; } - /** - * Check if a model supports video modality - */ modelSupportsVideo(modelId: string): boolean { return this.getModelModalities(modelId)?.video ?? false; } - /** - * Get model modalities as an array of ModelModality enum values - */ getModelModalitiesArray(modelId: string): ModelModality[] { const modalities = this.getModelModalities(modelId); if (!modalities) return []; const result: ModelModality[] = []; - if (modalities.vision) result.push(ModelModality.VISION); if (modalities.audio) result.push(ModelModality.AUDIO); if (modalities.video) result.push(ModelModality.VIDEO); @@ -193,16 +164,10 @@ class ModelsStore { return result; } - /** - * Get props for a specific model (from cache) - */ getModelProps(modelId: string): ApiLlamaCppServerProps | null { return this.modelPropsCache.get(modelId); } - /** - * Get context size (n_ctx) for a specific model from cached props - */ getModelContextSize(modelId: string): number | null { const props = this.getModelProps(modelId); const nCtx = props?.default_generation_settings?.n_ctx; @@ -210,17 +175,6 @@ class ModelsStore { return typeof nCtx === 'number' ? nCtx : null; } - /** - * Get context size for the currently selected model or null if no model is selected - */ - get selectedModelContextSize(): number | null { - if (!this.selectedModelName) return null; - return this.getModelContextSize(this.selectedModelName); - } - - /** - * Check if props are being fetched for a model - */ isModelPropsFetching(modelId: string): boolean { return this.modelPropsFetching.has(modelId); } @@ -235,10 +189,10 @@ class ModelsStore { isModelLoaded(modelId: string): boolean { const model = this.routerModels.find((m) => m.id === modelId); + return ( model?.status.value === ServerModelStatus.LOADED || - model?.status.value === ServerModelStatus.SLEEPING || - false + model?.status.value === ServerModelStatus.SLEEPING ); } @@ -248,6 +202,7 @@ class ModelsStore { getModelStatus(modelId: string): ServerModelStatus | null { const model = this.routerModels.find((m) => m.id === modelId); + return model?.status.value ?? null; } @@ -257,6 +212,7 @@ class ModelsStore { isModelInUse(modelId: string): boolean { const usage = this.modelUsage.get(modelId); + return usage !== undefined && usage.size > 0; } @@ -269,8 +225,8 @@ class ModelsStore { */ /** - * Fetch list of models from server and detect server role - * Also fetches modalities for MODEL mode (single model) + * Fetch list of models from server and detect server role. + * Also fetches modalities for MODEL mode (single model). */ async fetch(force = false): Promise { if (this.inflightFetch) return this.inflightFetch; @@ -293,69 +249,87 @@ class ModelsStore { await serverStore.fetch(); } - const response = await ModelsService.list(); + const router = isRouterMode(); - const models: ModelOption[] = response.data.map((item: ApiModelDataEntry, index: number) => { - const details = response.models?.[index]; - const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : []; - const displayNameSource = - details?.name && details.name.trim().length > 0 ? details.name : item.id; - const displayName = this.toDisplayName(displayNameSource); - const modelId = details?.model || item.id; + if (router) { + const response = await ModelsService.listRouter(); - return { - id: item.id, - name: displayName, - model: modelId, - description: details?.description, - capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)), - details: details?.details, - meta: item.meta ?? null, - parsedId: ModelsService.parseModelId(modelId), - aliases: item.aliases ?? [], - tags: item.tags ?? [] - } satisfies ModelOption; - }); + this.routerModels = response.data; + this.models = this.buildModelOptions(response); - this.models = models; + await this.fetchModalitiesForLoadedModels(); - // WORKAROUND: In MODEL mode, /props returns modalities for the single model, - // but /v1/models doesn't include modalities. We bridge this gap here. - const serverProps = serverStore.props; - if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) { - const modalities: ModelModalities = { - vision: serverProps.modalities.vision ?? false, - audio: serverProps.modalities.audio ?? false, - video: serverProps.modalities.video ?? false - }; - this.modelPropsCache.set(this.models[0].model, serverProps); - this.models = this.models.map((model, index) => - index === 0 ? { ...model, modalities } : model - ); + const visible = this.getVisibleModels(); + + if (visible.length === 1 && this.isModelLoaded(visible[0].model)) { + this.selectModelById(visible[0].id); + } + } else { + this.models = await this.fetchModelModeInternal(); } } catch (error) { this.models = []; this.error = error instanceof Error ? error.message : 'Failed to load models'; + throw error; } finally { this.loading = false; } } + /** Fetch models in MODEL mode (single model, standard OpenAI-compatible). */ + private async fetchModelModeInternal(): Promise { + const response = await ModelsService.list(); + + return this.buildModelOptions(response); + } + /** - * Fetch router models with full metadata (ROUTER mode only) - * This fetches the /models endpoint which returns status info for each model + * Build ModelOption[] from an API response. + * Both MODEL and ROUTER modes share the same mapping logic; + * they differ only in which endpoint is called. + */ + private buildModelOptions( + response: ApiModelListResponse | ApiRouterModelsListResponse + ): ModelOption[] { + return response.data.map((item: ApiModelDataEntry, index: number) => { + const details = response.models?.[index]; + const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : []; + const displayNameSource = + details?.name && details.name.trim().length > 0 ? details.name : item.id; + const modelId = details?.model || item.id; + + return { + id: item.id, + name: this.toDisplayName(displayNameSource), + model: modelId, + description: details?.description, + capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)), + details: details?.details, + meta: item.meta ?? null, + parsedId: ModelsService.parseModelId(modelId), + aliases: item.aliases ?? [], + tags: item.tags ?? [] + }; + }); + } + + /** + * Fetch router models with full metadata (ROUTER mode only). + * No-op in router mode — fetch() already calls listRouter() internally. + * Kept for API compatibility (e.g. handleOpenChange dropdown open handler). */ async fetchRouterModels(): Promise { + if (!isRouterMode()) return; + try { const response = await ModelsService.listRouter(); this.routerModels = response.data; await this.fetchModalitiesForLoadedModels(); - const o = this.models.filter((option) => this.getModelProps(option.model)?.ui !== false); - - if (o.length === 1 && this.isModelLoaded(o[0].model)) { - this.selectModelById(o[0].id); + const visible = this.getVisibleModels(); + if (visible.length === 1 && this.isModelLoaded(visible[0].model)) { + this.selectModelById(visible[0].id); } } catch (error) { console.warn('Failed to fetch router models:', error); @@ -364,10 +338,10 @@ class ModelsStore { } /** - * Fetch props for a specific model from /props endpoint - * Uses caching to avoid redundant requests + * Fetch props for a specific model from /props endpoint. + * Uses caching to avoid redundant requests. * - * In ROUTER mode, this will only fetch props if the model is loaded, + * In ROUTER mode, this only fetches props if the model is loaded, * since unloaded models return 400 from /props endpoint. * * @param modelId - Model identifier to fetch props for @@ -397,10 +371,7 @@ class ModelsStore { } } - /** - * Fetch modalities for all loaded models from /props endpoint - * This updates the modalities field in models array - */ + /** Fetch modalities for all loaded models from /props endpoint. */ async fetchModalitiesForLoadedModels(): Promise { const loadedModelIds = this.loadedModelIds; if (loadedModelIds.length === 0) return; @@ -410,7 +381,6 @@ class ModelsStore { try { const results = await Promise.all(propsPromises); - // Update models with modalities this.models = this.models.map((model) => { const modelIndex = loadedModelIds.indexOf(model.model); if (modelIndex === -1) return model; @@ -418,13 +388,7 @@ class ModelsStore { const props = results[modelIndex]; if (!props?.modalities) return model; - const modalities: ModelModalities = { - vision: props.modalities.vision ?? false, - audio: props.modalities.audio ?? false, - video: props.modalities.video ?? false - }; - - return { ...model, modalities }; + return { ...model, modalities: this.buildModalities(props.modalities) }; }); this.propsCacheVersion++; @@ -433,17 +397,38 @@ class ModelsStore { } } + /** + * Update modalities for a specific model. + * Called when a model is loaded or when we need fresh modality data. + */ + async updateModelModalities(modelId: string): Promise { + const props = await this.fetchModelProps(modelId); + if (!props?.modalities) return; + + this.models = this.models.map((model) => + model.model === modelId + ? { ...model, modalities: this.buildModalities(props.modalities!) } + : model + ); + + this.propsCacheVersion++; + } + + /** + * Filter to models visible in the UI (ui !== false). + */ + private getVisibleModels(): ModelOption[] { + return this.models.filter((option) => this.getModelProps(option.model)?.ui !== false); + } + /** * Gets the model name from the last assistant message in the active conversation. - * Iterates backward through messages to find the most recent message with a model. * Used by both the chat page and settings page to maintain model consistency. - * @returns The model name or null if not found */ getModelFromLastAssistantResponse(): string | null { const messages = conversationsStore.activeMessages; if (!messages || messages.length === 0) return null; - // Iterate backward to find the last message with a model for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].model) { return messages[i].model; @@ -456,22 +441,13 @@ class ModelsStore { /** * Auto-selects the model from the last assistant response if available and loaded. * Returns true if a model was selected, false otherwise. - * This is used by the chat page to maintain model consistency across page navigation. */ async selectModelFromLastAssistantResponse(): Promise { const lastModel = this.getModelFromLastAssistantResponse(); - if (!lastModel) return false; - - // Skip if already selected - if (this.selectedModelName === lastModel) return false; + if (!lastModel || this.selectedModelName === lastModel) return false; const matchingModel = this.models.find((option) => option.model === lastModel); - if (!matchingModel) return false; - - if (!this.isModelLoaded(lastModel)) { - console.log('[modelsStore] last assistant model not loaded:', lastModel); - return false; - } + if (!matchingModel || !this.isModelLoaded(lastModel)) return false; try { await this.selectModelById(matchingModel.id); @@ -484,22 +460,17 @@ class ModelsStore { } /** - * Auto-selects the first available model if none is selected, and fetches its props. + * Auto-selects the first available model if none is selected. * Prioritizes: * 1. Model from active conversation's last assistant response (if loaded) * 2. Model from active conversation's last assistant response (if not loaded) * 3. First loaded model (not from active conversation) * 4. First available model - * This is used to ensure default values are populated in settings pages. */ async ensureFirstModelSelected(): Promise { if (this.selectedModelName) return; - // Filter models that are visible in the UI - const availableModels = this.models.filter( - (option) => this.getModelProps(option.model)?.ui !== false - ); - + const availableModels = this.getVisibleModels(); if (availableModels.length === 0) return; // Try to select model from last assistant response first @@ -515,7 +486,7 @@ class ModelsStore { } } - // Try to find a loaded model first + // Try a loaded model first const loadedModel = availableModels.find((m) => this.isModelLoaded(m.model)); if (loadedModel) { await this.selectModelById(loadedModel.id); @@ -524,34 +495,7 @@ class ModelsStore { } // Fall back to the first available model - const firstModel = availableModels[0]; - await this.selectModelById(firstModel.id); - // Don't fetch props for unloaded models (will fail in ROUTER mode) - } - - /** - * Update modalities for a specific model - * Called when a model is loaded or when we need fresh modality data - */ - async updateModelModalities(modelId: string): Promise { - try { - const props = await this.fetchModelProps(modelId); - if (!props?.modalities) return; - - const modalities: ModelModalities = { - vision: props.modalities.vision ?? false, - audio: props.modalities.audio ?? false, - video: props.modalities.video ?? false - }; - - this.models = this.models.map((model) => - model.model === modelId ? { ...model, modalities } : model - ); - - this.propsCacheVersion++; - } catch (error) { - console.warn(`Failed to update modalities for model ${modelId}:`, error); - } + await this.selectModelById(availableModels[0].id); } /** @@ -562,9 +506,6 @@ class ModelsStore { * */ - /** - * Select a model for new conversations - */ async selectModelById(modelId: string): Promise { if (!modelId || this.updating) return; if (this.selectedModelId === modelId) return; @@ -584,8 +525,7 @@ class ModelsStore { } /** - * Select a model by its model name (used for syncing with conversation model) - * @param modelName - Model name to select (e.g., "ggml-org/GLM-4.7-Flash-GGUF") + * Select a model by its model name (used for syncing with conversation model). */ selectModelByName(modelName: string): void { const option = this.models.find((model) => model.model === modelName); @@ -615,7 +555,7 @@ class ModelsStore { /** * * - * Loading/Unloading Models + * Loading / Unloading Models * * */ @@ -623,27 +563,18 @@ class ModelsStore { /** * WORKAROUND: Polling for model status after load/unload operations. * - * Currently, the `/models/load` and `/models/unload` endpoints return success - * before the operation actually completes on the server. This means an immediate - * request to `/models` returns stale status (e.g., "loading" after load request, - * "loaded" after unload request). + * Currently, `/models/load` and `/models/unload` return success before + * the operation actually completes on the server. * - * TODO: Remove this polling once llama-server properly waits for the operation - * to complete before returning success from `/load` and `/unload` endpoints. - * At that point, a single `fetchRouterModels()` call after the operation will - * be sufficient to get the correct status. + * TODO: Remove polling once llama-server properly waits for the operation + * to complete before returning success. */ - /** Polling interval in ms for checking model status */ private static readonly STATUS_POLL_INTERVAL = 500; /** * Poll for expected model status after load/unload operation. - * Keeps polling indefinitely until the model reaches the expected status or fails. - * - * @param modelId - Model identifier to check - * @param expectedStatus - Expected status to wait for - * @throws Error if model reaches FAILED status + * Keeps polling until the model reaches the expected status or fails. */ private async pollForModelStatus( modelId: string, @@ -654,9 +585,7 @@ class ModelsStore { await this.fetchRouterModels(); const currentStatus = this.getModelStatus(modelId); - if (currentStatus === expectedStatus) { - return; - } + if (currentStatus === expectedStatus) return; if (currentStatus === ServerModelStatus.FAILED) { throw new Error( @@ -677,15 +606,8 @@ class ModelsStore { } } - /** - * Load a model (ROUTER mode) - * @param modelId - Model identifier to load - */ async loadModel(modelId: string): Promise { - if (this.isModelLoaded(modelId)) { - return; - } - + if (this.isModelLoaded(modelId)) return; if (this.modelLoadingStates.get(modelId)) return; this.modelLoadingStates.set(modelId, true); @@ -694,7 +616,6 @@ class ModelsStore { try { await ModelsService.load(modelId); await this.pollForModelStatus(modelId, ServerModelStatus.LOADED); - await this.updateModelModalities(modelId); toast.success(`Model loaded: ${this.toDisplayName(modelId)}`); } catch (error) { @@ -706,15 +627,8 @@ class ModelsStore { } } - /** - * Unload a model (ROUTER mode) - * @param modelId - Model identifier to unload - */ async unloadModel(modelId: string): Promise { - if (!this.isModelLoaded(modelId)) { - return; - } - + if (!this.isModelLoaded(modelId)) return; if (this.modelLoadingStates.get(modelId)) return; this.modelLoadingStates.set(modelId, true); @@ -722,7 +636,6 @@ class ModelsStore { try { await ModelsService.unload(modelId); - await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED); toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`); } catch (error) { @@ -734,15 +647,8 @@ class ModelsStore { } } - /** - * Ensure a model is loaded before use - * @param modelId - Model identifier to ensure is loaded - */ async ensureModelLoaded(modelId: string): Promise { - if (this.isModelLoaded(modelId)) { - return; - } - + if (this.isModelLoaded(modelId)) return; await this.loadModel(modelId); } @@ -779,11 +685,9 @@ class ModelsStore { private loadFavoritesFromStorage(): Set { try { const raw = localStorage.getItem(FAVORITE_MODELS_LOCALSTORAGE_KEY); - return raw ? new Set(JSON.parse(raw) as string[]) : new Set(); } catch { toast.error('Failed to load favorite models from local storage'); - return new Set(); } } @@ -799,10 +703,19 @@ class ModelsStore { private toDisplayName(id: string): string { const segments = id.split(/\\|\//); const candidate = segments.pop(); - return candidate && candidate.trim().length > 0 ? candidate : id; } + private buildModalities( + modalities: NonNullable + ): ModelModalities { + return { + vision: modalities.vision ?? false, + audio: modalities.audio ?? false, + video: modalities.video ?? false + }; + } + clear(): void { this.models = []; this.routerModels = []; diff --git a/tools/ui/src/lib/types/api.d.ts b/tools/ui/src/lib/types/api.d.ts index 316ad5528..5f0a38dd3 100644 --- a/tools/ui/src/lib/types/api.d.ts +++ b/tools/ui/src/lib/types/api.d.ts @@ -203,6 +203,7 @@ export interface ApiLlamaCppServerProps { /** @deprecated Use {@link ui_settings} instead */ webui_settings?: Record; ui_settings?: Record; + cors_proxy_enabled?: boolean; } export interface ApiChatCompletionRequest { diff --git a/tools/ui/src/lib/types/mcp.d.ts b/tools/ui/src/lib/types/mcp.d.ts index 7aa050cdf..2a2926142 100644 --- a/tools/ui/src/lib/types/mcp.d.ts +++ b/tools/ui/src/lib/types/mcp.d.ts @@ -1,5 +1,5 @@ -import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp'; -import type { ToolSource } from '$lib/enums/tools'; +import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp.enums'; +import type { ToolSource } from '$lib/enums/tools.enums'; import type { Client, ClientCapabilities as SDKClientCapabilities, diff --git a/tools/ui/src/lib/utils/api-key-validation.ts b/tools/ui/src/lib/utils/api-key-validation.ts index 948b7d7b6..dbbf9a09b 100644 --- a/tools/ui/src/lib/utils/api-key-validation.ts +++ b/tools/ui/src/lib/utils/api-key-validation.ts @@ -12,17 +12,21 @@ export async function validateApiKey(fetch: typeof globalThis.fetch): Promise = { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}` }; - if (apiKey) { - headers.Authorization = `Bearer ${apiKey}`; - } - const response = await fetch(`${base}/props`, { headers }); if (!response.ok) { diff --git a/tools/ui/src/lib/utils/legacy-migration.ts b/tools/ui/src/lib/utils/legacy-migration.ts index 19755f6ee..6b0890a36 100644 --- a/tools/ui/src/lib/utils/legacy-migration.ts +++ b/tools/ui/src/lib/utils/legacy-migration.ts @@ -333,7 +333,8 @@ async function migrateConversation(convId: string): Promise { export async function runLegacyMigration(): Promise { if (!isMigrationNeeded()) return; - console.log('[Migration] Starting legacy message format migration...'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) + console.log('[Migration] Starting legacy message format migration...'); try { const conversations = await DatabaseService.getAllConversations(); @@ -344,12 +345,14 @@ export async function runLegacyMigration(): Promise { totalMigrated += count; } - if (totalMigrated > 0) { - console.log( - `[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations` - ); - } else { - console.log('[Migration] No legacy messages found, marking as done'); + if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) { + if (totalMigrated > 0) { + console.log( + `[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations` + ); + } else { + console.log('[Migration] No legacy messages found, marking as done'); + } } markMigrationDone(); diff --git a/tools/ui/src/routes/(chat)/+page.svelte b/tools/ui/src/routes/(chat)/+page.svelte index c272b438e..9db1d445f 100644 --- a/tools/ui/src/routes/(chat)/+page.svelte +++ b/tools/ui/src/routes/(chat)/+page.svelte @@ -3,7 +3,6 @@ import { chatStore } from '$lib/stores/chat.svelte'; import { conversationsStore, isConversationsInitialized } from '$lib/stores/conversations.svelte'; import { modelsStore, modelOptions } from '$lib/stores/models.svelte'; - import { isRouterMode } from '$lib/stores/server.svelte'; import { onMount } from 'svelte'; import { page } from '$app/state'; import { replaceState } from '$app/navigation'; @@ -72,23 +71,13 @@ conversationsStore.clearActiveConversation(); chatStore.clearUIState(); - if ( - isRouterMode() && - modelsStore.selectedModelName && - !modelsStore.isModelLoaded(modelsStore.selectedModelName) - ) { - modelsStore.clearSelection(); + await modelsStore.fetch(); - const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model)); - if (first) { - await modelsStore.selectModelById(first.id); - } - } - - // Handle URL params only if we have ?q= or ?model= or ?new_chat=true if (qParam !== null || modelParam !== null || newChatParam === 'true') { await handleUrlParams(); } + + await modelsStore.ensureFirstModelSelected(); }); diff --git a/tools/ui/src/routes/+layout.svelte b/tools/ui/src/routes/+layout.svelte index e03d13fef..0610b07ae 100644 --- a/tools/ui/src/routes/+layout.svelte +++ b/tools/ui/src/routes/+layout.svelte @@ -7,11 +7,13 @@ import { untrack } from 'svelte'; import { onMount } from 'svelte'; import { fade } from 'svelte/transition'; + import { DesktopIconStrip, DialogConversationTitleUpdate, SidebarNavigation } from '$lib/components/app'; + import { conversationsStore } from '$lib/stores/conversations.svelte'; import * as Sidebar from '$lib/components/ui/sidebar/index.js'; import * as Tooltip from '$lib/components/ui/tooltip'; @@ -30,26 +32,29 @@ import { conversations } from '$lib/stores/conversations.svelte'; let { children } = $props(); - let alwaysShowSidebarOnDesktop = $derived(config().alwaysShowSidebarOnDesktop); let isMobile = new IsMobile(); let isDesktop = $derived(!isMobile.current); let sidebarOpen = $state(false); let mounted = $state(false); let innerHeight = $state(); + let chatSidebar: - | { activateSearchMode?: () => void; editActiveConversation?: () => void } + | { + activateSearchMode?: () => void; + editActiveConversation?: () => void; + } | undefined = $state(); let titleUpdateDialogOpen = $state(false); let titleUpdateCurrentTitle = $state(''); let titleUpdateNewTitle = $state(''); let titleUpdateResolve: ((value: boolean) => void) | null = null; - const panelNav = useSettingsNavigation(); function navigateToConversation(direction: -1 | 1) { const allConvs = conversations(); + if (allConvs.length === 0) return; const currentId = page.params.id; @@ -61,6 +66,7 @@ } const idx = allConvs.findIndex((c) => c.id === currentId); + if (idx === -1) return; const targetIdx = idx + direction; @@ -75,38 +81,41 @@ // Global keyboard shortcuts const { handleKeydown } = useKeyboardShortcuts({ editActiveConversation: () => chatSidebar?.editActiveConversation?.(), - navigateToPrevConversation: () => navigateToConversation(-1), - navigateToNextConversation: () => navigateToConversation(1) }); function checkApiKey() { const apiKey = config().apiKey; - if ( - (page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') && - page.status !== 401 && - page.status !== 403 - ) { - const headers: Record = { - 'Content-Type': 'application/json' - }; - - if (apiKey && apiKey.trim() !== '') { - headers.Authorization = `Bearer ${apiKey.trim()}`; - } - - fetch(`${base}/props`, { headers }) - .then((response) => { - if (response.status === 401 || response.status === 403) { - window.location.reload(); - } - }) - .catch((e) => { - console.error('Error checking API key:', e); - }); + // No API key configured — server doesn't require auth, no need to validate. + // This mirrors the early return in validateApiKey() to avoid redundant /props requests. + if (!apiKey || apiKey.trim() === '') { + return; } + + untrack(() => { + if ( + (page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') && + page.status !== 401 && + page.status !== 403 + ) { + const headers: Record = { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey.trim()}` + }; + + fetch(`${base}/props`, { headers }) + .then((response) => { + if (response.status === 401 || response.status === 403) { + window.location.reload(); + } + }) + .catch((e) => { + console.error('Error checking API key:', e); + }); + } + }); } function handleTitleUpdateCancel() { @@ -134,6 +143,7 @@ $effect(() => { if (alwaysShowSidebarOnDesktop && isDesktop) { sidebarOpen = true; + return; } }); @@ -170,6 +180,7 @@ // Only fetch router models once when we have models loaded and in router mode if (isRouter && modelsCount > 0 && !routerModelsFetched) { routerModelsFetched = true; + untrack(() => { modelsStore.fetchRouterModels(); }); @@ -218,7 +229,6 @@ - -
- - - +
+ {#if !(alwaysShowSidebarOnDesktop && isDesktop) && !(panelNav.isSettingsRoute && !isDesktop)} {#if mounted} @@ -261,9 +271,9 @@ /> {/if} - - {@render children?.()} - + {@render children?.()}
diff --git a/tools/ui/src/styles/katex-custom.scss b/tools/ui/src/styles/katex-custom.scss index 9c8b96ed5..0e385844a 100644 --- a/tools/ui/src/styles/katex-custom.scss +++ b/tools/ui/src/styles/katex-custom.scss @@ -8,6 +8,9 @@ $use-ttf: false; $font-folder: 'katex-fonts'; // Import KaTeX SCSS with overridden variables -// Note: @import is deprecated but required because KaTeX uses @import internally -// The deprecation warnings are from KaTeX's code and cannot be avoided -@import 'katex/src/styles/katex.scss'; +@use 'katex/src/styles/katex.scss' with ( + $use-woff2: true, + $use-woff: false, + $use-ttf: false, + $font-folder: 'katex-fonts' +); diff --git a/tools/ui/tests/client/page.svelte.test.ts b/tools/ui/tests/client/page.svelte.test.ts index 6849beb27..32e333d7f 100644 --- a/tools/ui/tests/client/page.svelte.test.ts +++ b/tools/ui/tests/client/page.svelte.test.ts @@ -4,8 +4,9 @@ import TestWrapper from './components/TestWrapper.svelte'; describe('/+page.svelte', () => { it('should render page without throwing', async () => { - // Basic smoke test - page should render without throwing errors - // API calls will fail in test environment but component should still mount - expect(() => render(TestWrapper)).not.toThrow(); + // Basic smoke test - page should render without throwing errors. + // API calls are mocked in vitest-setup-client.ts. + await render(TestWrapper); + expect(true).toBe(true); }); }); diff --git a/tools/ui/vite.config.ts b/tools/ui/vite.config.ts index d3db24bf2..f89a689d5 100644 --- a/tools/ui/vite.config.ts +++ b/tools/ui/vite.config.ts @@ -23,18 +23,6 @@ export default defineConfig({ minify: true }, - css: { - preprocessorOptions: { - scss: { - additionalData: ` - $use-woff2: true; - $use-woff: false; - $use-ttf: false; - ` - } - } - }, - plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()], test: { diff --git a/tools/ui/vitest-setup-client.ts b/tools/ui/vitest-setup-client.ts index 570b9f0e1..90994442e 100644 --- a/tools/ui/vitest-setup-client.ts +++ b/tools/ui/vitest-setup-client.ts @@ -1,2 +1,80 @@ /// /// + +import { beforeEach, vi } from 'vitest'; + +// Mock fetch for API calls during client tests. +// In test environment there is no backend server, so we intercept +// the specific endpoints the app uses and return valid mock data. +beforeEach(() => { + const originalFetch = globalThis.fetch; + + vi.spyOn(globalThis, 'fetch').mockImplementation( + async (input: RequestInfo | URL, init?: RequestInit) => { + const url = typeof input === 'string' ? input : input instanceof URL ? input.href : input.url; + + // Mock server props endpoint + if (url.includes('/server')) { + return new Response( + JSON.stringify({ + mode: 'router', + version: 'test', + git_commit: 'test', + git_branch: 'test' + }), + { status: 200, headers: { 'Content-Type': 'application/json' } } + ); + } + + // Mock models list endpoint + if (/\/v1\/models|\/models\b/.test(url)) { + return new Response( + JSON.stringify({ + object: 'list', + data: [ + { + id: 'test-model.gguf', + object: 'model', + owned_by: 'llamacpp', + created: 0, + in_cache: false, + path: 'models/test-model.gguf', + status: { value: 'unloaded' }, + meta: {} + } + ], + models: [ + { + model: 'test-model.gguf', + name: 'Test Model', + details: {} + } + ] + }), + { status: 200, headers: { 'Content-Type': 'application/json' } } + ); + } + + // Mock /props endpoint (used for modalities) + if (url.includes('/props')) { + return new Response( + JSON.stringify({ + default_generation_settings: { n_ctx: 2048 } + }), + { status: 200, headers: { 'Content-Type': 'application/json' } } + ); + } + + // Mock /tools endpoint (used for built-in tools list) + if (url.includes('/tools')) { + return new Response(JSON.stringify([]), { + status: 200, + headers: { 'Content-Type': 'application/json' } + }); + } + + // Default: use real fetch + return originalFetch(input, init); + } + ); +}); diff --git a/tools/ui/vitest.shims.d.ts b/tools/ui/vitest.shims.d.ts new file mode 100644 index 000000000..03b1801a6 --- /dev/null +++ b/tools/ui/vitest.shims.d.ts @@ -0,0 +1 @@ +///