mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .devops/cann.Dockerfile # .devops/cpu.Dockerfile # .devops/cuda.Dockerfile # .devops/intel.Dockerfile # .devops/llama-cli-cann.Dockerfile # .devops/musa.Dockerfile # .devops/openvino.Dockerfile # .devops/rocm.Dockerfile # .devops/s390x.Dockerfile # .devops/vulkan.Dockerfile # .github/ISSUE_TEMPLATE/011-bug-results.yml # .github/ISSUE_TEMPLATE/019-bug-misc.yml # .github/workflows/build-and-test-snapdragon.yml # .github/workflows/docker.yml # .github/workflows/server-self-hosted.yml # .github/workflows/ui-ci.yml # .pi/gg/SYSTEM.md # README.md # common/arg.cpp # docs/backend/SYCL.md # docs/backend/snapdragon/CMakeUserPresets.json # docs/backend/snapdragon/README.md # docs/speculative.md # examples/save-load-state/save-load-state.cpp # ggml/src/ggml-hexagon/ggml-hexagon.cpp # ggml/src/ggml-hexagon/htp/CMakeLists.txt # ggml/src/ggml-hexagon/htp/htp-ctx.h # ggml/src/ggml-hexagon/htp/htp-ops.h # ggml/src/ggml-hexagon/htp/main.c # ggml/src/ggml-hexagon/htp/rope-ops.c # ggml/src/ggml-hexagon/htp/unary-ops.c # ggml/src/ggml-opencl/CMakeLists.txt # ggml/src/ggml-opencl/ggml-opencl.cpp # ggml/src/ggml-opencl/kernels/cvt.cl # ggml/src/ggml-sycl/ggml-sycl.cpp # ggml/src/ggml-webgpu/ggml-webgpu.cpp # ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl # tools/cli/README.md # tools/server/README.md
This commit is contained in:
commit
7d987af23a
77 changed files with 1291 additions and 1048 deletions
|
|
@ -4,7 +4,6 @@
|
|||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "hf-cache.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
|
|
@ -539,7 +538,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
|||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
const bool skip = (arg == "--spec-type");
|
||||
|
||||
if (!skip) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
}
|
||||
auto & tmp = arg_to_options[arg];
|
||||
auto opt = *tmp.first;
|
||||
|
|
@ -588,12 +591,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
|||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// TODO: Remove later // KCPP: remove for now
|
||||
// try {
|
||||
// hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
|
||||
// } catch (const std::exception & e) {
|
||||
// LOG_WRN("HF cache migration failed: %s\n", e.what());
|
||||
// }
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
|
|
@ -902,7 +899,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
|||
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
|
||||
}
|
||||
if (!seen_args.insert(arg).second) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
const bool skip = (arg == "--spec-type");
|
||||
|
||||
if (!skip) {
|
||||
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
|
||||
}
|
||||
}
|
||||
auto opt = *arg_to_options[arg];
|
||||
std::string val;
|
||||
|
|
@ -3365,7 +3366,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
" - 1: error\n"
|
||||
" - 2: warning\n"
|
||||
" - 3: info\n"
|
||||
" - 4: debug\n"
|
||||
" - 4: trace (more info)\n"
|
||||
" - 5: debug\n"
|
||||
"(default: %d)\n", params.verbosity),
|
||||
[](common_params & params, int value) {
|
||||
params.verbosity = value;
|
||||
|
|
@ -4126,6 +4128,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.speculative.ngram_mod.n_match = 24;
|
||||
params.speculative.ngram_mod.n_min = 48;
|
||||
params.speculative.ngram_mod.n_max = 64;
|
||||
|
||||
// TODO: not sure if this is a good config - explore more settings and potentially enable it
|
||||
//params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
|
||||
//params.speculative.ngram_map_k4v.size_n = 8;
|
||||
//params.speculative.ngram_map_k4v.size_m = 24;
|
||||
//params.speculative.ngram_map_k4v.min_hits = 2;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
|
||||
|
|
|
|||
|
|
@ -1166,7 +1166,7 @@ struct common_init_result::impl {
|
|||
std::vector<llama_sampler_seq_config> samplers_seq_config;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
common_init_result::common_init_result(common_params & params, bool model_only) :
|
||||
pimpl(new impl{}) {
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
|
@ -1179,7 +1179,7 @@ common_init_result::common_init_result(common_params & params) :
|
|||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
|
|
@ -1189,6 +1189,10 @@ common_init_result::common_init_result(common_params & params) :
|
|||
|
||||
pimpl->model.reset(model);
|
||||
|
||||
if (model_only) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
|
|
@ -1258,29 +1262,6 @@ common_init_result::common_init_result(common_params & params) :
|
|||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
}
|
||||
|
||||
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
|
||||
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
|
||||
// currently this is not supported. so we disable the partial rollback
|
||||
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
|
||||
auto & types = params.speculative.types;
|
||||
|
||||
for (int i = 0; i < (int) types.size(); i++) {
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
|
||||
continue;
|
||||
}
|
||||
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cparams.n_rs_seq = 0;
|
||||
|
||||
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
|
||||
common_speculative_type_to_str(types[i]).c_str());
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
|
|
@ -1315,8 +1296,8 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
|||
return pimpl->lora;
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
common_init_result_ptr common_init_from_params(common_params & params, bool model_only) {
|
||||
common_init_result_ptr res(new common_init_result(params, model_only));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
|
|
@ -1324,6 +1305,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
|||
return res;
|
||||
}
|
||||
|
||||
if (model_only) {
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
|
|
@ -1387,7 +1372,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
|||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
llama_set_warmup(lctx, true);
|
||||
|
||||
|
|
|
|||
|
|
@ -300,11 +300,11 @@ struct common_params_model {
|
|||
|
||||
// draft-model-based speculative decoding parameters
|
||||
struct common_params_speculative_draft {
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.0f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
common_params_model mparams;
|
||||
|
||||
|
|
@ -858,7 +858,7 @@ struct common_sampler;
|
|||
|
||||
// note: defines the model, context, samplers, ets. lifetimes
|
||||
struct common_init_result {
|
||||
common_init_result(common_params & params);
|
||||
common_init_result(common_params & params, bool model_only = false);
|
||||
~common_init_result();
|
||||
|
||||
llama_model * model();
|
||||
|
|
@ -876,7 +876,7 @@ private:
|
|||
|
||||
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params);
|
||||
common_init_result_ptr common_init_from_params(common_params & params, bool model_only = false);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <atomic>
|
||||
#include <regex> // migration only
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <stdexcept>
|
||||
|
|
@ -336,15 +335,9 @@ hf_files get_repo_files(const std::string & repo_id,
|
|||
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
|
||||
file.oid = item["lfs"]["oid"].get<std::string>();
|
||||
}
|
||||
if (item["lfs"].contains("size") && item["lfs"]["size"].is_number()) {
|
||||
file.size = item["lfs"]["size"].get<size_t>();
|
||||
}
|
||||
} else if (item.contains("oid") && item["oid"].is_string()) {
|
||||
file.oid = item["oid"].get<std::string>();
|
||||
}
|
||||
if (file.size == 0 && item.contains("size") && item["size"].is_number()) {
|
||||
file.size = item["size"].get<size_t>();
|
||||
}
|
||||
|
||||
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
|
||||
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
|
||||
|
|
@ -502,271 +495,4 @@ std::string finalize_file(const hf_file & file) {
|
|||
return file.final_path;
|
||||
}
|
||||
|
||||
// delete everything after this line, one day
|
||||
|
||||
// copied from download.cpp without the tag part
|
||||
struct gguf_split_info {
|
||||
std::string prefix; // tag included
|
||||
int index;
|
||||
int count;
|
||||
};
|
||||
|
||||
static gguf_split_info get_gguf_split_info(const std::string & path) {
|
||||
static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
|
||||
std::smatch m;
|
||||
|
||||
std::string prefix = path;
|
||||
if (!string_remove_suffix(prefix, ".gguf")) {
|
||||
return {};
|
||||
}
|
||||
|
||||
int index = 1;
|
||||
int count = 1;
|
||||
|
||||
if (std::regex_match(prefix, m, re_split)) {
|
||||
index = std::stoi(m[2].str());
|
||||
count = std::stoi(m[3].str());
|
||||
prefix = m[1].str();
|
||||
}
|
||||
|
||||
return {std::move(prefix), index, count};
|
||||
}
|
||||
|
||||
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
|
||||
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
|
||||
std::smatch match;
|
||||
if (std::regex_match(filename, match, re)) {
|
||||
return {match[1].str(), match[2].str()};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::string make_old_cache_filename(const std::string & owner,
|
||||
const std::string & repo,
|
||||
const std::string & filename) {
|
||||
auto result = owner + "_" + repo + "_" + filename;
|
||||
string_replace_all(result, "/", "_");
|
||||
return result;
|
||||
}
|
||||
|
||||
struct migrate_file {
|
||||
std::string path;
|
||||
std::string sha256;
|
||||
size_t size;
|
||||
fs::path old_path;
|
||||
fs::path etag_path;
|
||||
const hf_file * file;
|
||||
};
|
||||
|
||||
using migrate_files = std::vector<migrate_file>;
|
||||
|
||||
static bool collect_file(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const std::string & path,
|
||||
const std::string & sha256,
|
||||
const hf_files & files,
|
||||
migrate_files & to_migrate) {
|
||||
|
||||
const hf_file * file = nullptr;
|
||||
|
||||
for (const auto & f : files) {
|
||||
if (f.path == path) {
|
||||
file = &f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string old_filename = make_old_cache_filename(owner, repo, path);
|
||||
fs::path old_path = old_cache / old_filename;
|
||||
fs::path etag_path = old_path.string() + ".etag";
|
||||
|
||||
if (!fs::exists(old_path)) {
|
||||
if (file && fs::exists(file->final_path)) {
|
||||
return true;
|
||||
}
|
||||
LOG_WRN("%s: %s not found in old cache or HF cache\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!file) {
|
||||
LOG_WRN("%s: %s not found in current repo\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!sha256.empty() && !file->oid.empty() && sha256 != file->oid) {
|
||||
LOG_WRN("%s: %s is not up to date (sha256 mismatch)\n", __func__, old_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (file->size > 0) {
|
||||
size_t size = fs::file_size(old_path);
|
||||
if (size != file->size) {
|
||||
LOG_WRN("%s: %s has wrong size %zu (expected %zu)\n", __func__, old_filename.c_str(), size, file->size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
to_migrate.push_back({path, sha256, file->size, old_path, etag_path, file});
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool collect_files(const fs::path & old_cache,
|
||||
const std::string & owner,
|
||||
const std::string & repo,
|
||||
const nl::json & node,
|
||||
const hf_files & files,
|
||||
migrate_files & to_migrate) {
|
||||
|
||||
if (!node.contains("rfilename") ||
|
||||
!node.contains("lfs") ||
|
||||
!node["lfs"].contains("sha256")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string path = node["rfilename"];
|
||||
std::string sha256 = node["lfs"]["sha256"];
|
||||
|
||||
auto split = get_gguf_split_info(path);
|
||||
|
||||
if (split.count <= 1) {
|
||||
return collect_file(old_cache, owner, repo, path, sha256, files, to_migrate);
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> splits;
|
||||
|
||||
for (const auto & f : files) {
|
||||
auto split_f = get_gguf_split_info(f.path);
|
||||
if (split_f.count == split.count && split_f.prefix == split.prefix) {
|
||||
// sadly the manifest only provides the sha256 of the first file (index == 1)
|
||||
// the rest will be verified using the size...
|
||||
std::string f_sha256 = (split_f.index == 1) ? sha256 : "";
|
||||
splits.emplace_back(f.path, f_sha256);
|
||||
}
|
||||
}
|
||||
|
||||
if ((int)splits.size() != split.count) {
|
||||
LOG_WRN("%s: expected %d split files but found %d in repo\n", __func__, split.count, (int)splits.size());
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [f_path, f_sha256] : splits) {
|
||||
if (!collect_file(old_cache, owner, repo, f_path, f_sha256, files, to_migrate)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool migrate_file(const migrate_file & file) {
|
||||
std::error_code ec;
|
||||
|
||||
fs::path new_path(file.file->local_path);
|
||||
fs::create_directories(new_path.parent_path(), ec);
|
||||
|
||||
if (!fs::exists(new_path, ec)) {
|
||||
fs::rename(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
fs::copy_file(file.old_path, new_path, ec);
|
||||
if (ec) {
|
||||
LOG_ERR("%s: failed to move/copy %s: %s\n", __func__, file.old_path.string().c_str(), ec.message().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
fs::remove(file.old_path, ec);
|
||||
}
|
||||
fs::remove(file.etag_path, ec);
|
||||
|
||||
std::string filename = finalize_file(*file.file);
|
||||
LOG_INF("%s: migrated %s -> %s\n", __func__, file.old_path.filename().string().c_str(), filename.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
|
||||
fs::path old_cache = fs_get_cache_directory();
|
||||
if (!fs::exists(old_cache)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (offline) {
|
||||
LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
|
||||
return; // -hf is not going to work
|
||||
}
|
||||
|
||||
bool warned = false;
|
||||
|
||||
for (const auto & entry : fs::directory_iterator(old_cache)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
auto filename = entry.path().filename().string();
|
||||
auto [owner, repo] = parse_manifest_name(filename);
|
||||
|
||||
if (owner.empty() || repo.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!warned) {
|
||||
warned = true;
|
||||
LOG_WRN("================================================================================\n"
|
||||
"WARNING: Migrating cache to HuggingFace cache directory\n"
|
||||
" Old cache: %s\n"
|
||||
" New cache: %s\n"
|
||||
"This one-time migration moves models previously downloaded with -hf\n"
|
||||
"from the legacy llama.cpp cache to the standard HuggingFace cache.\n"
|
||||
"Models downloaded with --model-url are not affected.\n"
|
||||
"================================================================================\n",
|
||||
old_cache.string().c_str(), get_cache_directory().string().c_str());
|
||||
}
|
||||
|
||||
auto repo_id = owner + "/" + repo;
|
||||
auto files = get_repo_files(repo_id, token);
|
||||
|
||||
if (files.empty()) {
|
||||
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
migrate_files to_migrate;
|
||||
bool ok = true;
|
||||
|
||||
try {
|
||||
std::ifstream manifest(entry.path());
|
||||
auto json = nl::json::parse(manifest);
|
||||
for (const char * key : {"ggufFile", "mmprojFile"}) {
|
||||
if (json.contains(key)) {
|
||||
if (!collect_files(old_cache, owner, repo, json[key], files, to_migrate)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration skipped: one or more files failed validation\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto & file : to_migrate) {
|
||||
if (!migrate_file(file)) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LOG_WRN("%s: migration failed: could not migrate all files\n", __func__);
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG_INF("%s: migration complete, deleting manifest: %s\n", __func__, entry.path().string().c_str());
|
||||
fs::remove(entry.path());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace hf_cache
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ struct hf_file {
|
|||
std::string final_path;
|
||||
std::string oid;
|
||||
std::string repo_id;
|
||||
size_t size = 0; // only for the migration
|
||||
};
|
||||
|
||||
using hf_files = std::vector<hf_file>;
|
||||
|
|
@ -30,7 +29,4 @@ hf_files get_cached_files(const std::string & repo_id = {});
|
|||
// Create snapshot path (link or move/copy) and return it
|
||||
std::string finalize_file(const hf_file & file);
|
||||
|
||||
// TODO: Remove later
|
||||
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
|
||||
|
||||
} // namespace hf_cache
|
||||
|
|
|
|||
|
|
@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
|||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,19 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
|
|||
{"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
|
||||
};
|
||||
|
||||
static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
|
||||
if (devices.empty()) {
|
||||
return "default";
|
||||
}
|
||||
|
||||
std::string result;
|
||||
for (size_t i = 0; i < devices.size(); i++) {
|
||||
if (i > 0) result += ", ";
|
||||
result += ggml_backend_dev_name(devices[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
struct common_speculative_config {
|
||||
common_speculative_type type;
|
||||
common_params_speculative params;
|
||||
|
|
@ -144,7 +157,7 @@ struct common_speculative_impl {
|
|||
|
||||
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
|
||||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
|
@ -167,6 +180,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
|||
auto * ctx_dft = this->params.ctx_dft;
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
ctx_tgt ? "yes" : "no",
|
||||
ctx_dft ? "yes" : "no",
|
||||
common_speculative_get_devices_str(this->params.devices).c_str());
|
||||
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
|
||||
// TODO: optimize or pass from outside?
|
||||
|
|
@ -343,7 +366,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
|
|
@ -355,8 +378,12 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
|||
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
//common_params_speculative_eagle3 params;
|
||||
|
||||
common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {}
|
||||
common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
|
||||
// noop
|
||||
|
|
@ -371,7 +398,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
|||
// TODO: implement
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
|
|
@ -380,7 +407,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
|||
}
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
||||
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
|
||||
|
||||
llama_batch batch;
|
||||
|
|
@ -407,7 +434,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
|
||||
common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
|
|
@ -417,6 +444,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
|
||||
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
ctx_tgt ? "yes" : "no",
|
||||
ctx_dft ? "yes" : "no",
|
||||
common_speculative_get_devices_str(this->params.devices).c_str());
|
||||
|
||||
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
|
||||
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
|
||||
// llama_batch_init allocates only one of token/embd; MTP needs both.
|
||||
|
|
@ -427,7 +464,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
for (auto & s : smpls) {
|
||||
common_params_sampling sparams;
|
||||
sparams.no_perf = false;
|
||||
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
|
||||
sparams.top_k = 10;
|
||||
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
|
||||
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
|
||||
}
|
||||
|
|
@ -446,7 +483,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_state_draft_mtp() override {
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
if (batch.token != nullptr) {
|
||||
free(batch.token);
|
||||
batch.token = nullptr;
|
||||
|
|
@ -462,7 +499,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d — "
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
"Drafts may degrade.\n",
|
||||
|
|
@ -633,6 +670,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
// add drafted token for each sequence
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
drafting[seq_id] = false;
|
||||
n_drafting--;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
auto & dp = dparams.at(seq_id);
|
||||
|
|
@ -678,7 +723,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override {
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -714,7 +759,12 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
|||
common_ngram_simple_config config)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
|
||||
, params(params.ngram_simple)
|
||||
, config(config) {}
|
||||
, config(config)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
|
||||
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
|
||||
this->params.size_n, this->params.size_m, this->params.min_hits);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
|
||||
// noop
|
||||
|
|
@ -738,7 +788,7 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
|
|
@ -748,20 +798,21 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
|||
};
|
||||
|
||||
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
common_params_speculative_ngram_map params;
|
||||
|
||||
// n_seq configs
|
||||
std::vector<common_ngram_map> config;
|
||||
|
||||
common_speculative_impl_ngram_map_k(
|
||||
const common_params_speculative & params,
|
||||
const common_ngram_map & config,
|
||||
uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
|
||||
, params(params.ngram_map_k) {
|
||||
{
|
||||
for (uint32_t i = 0; i < n_seq; i++) {
|
||||
this->config.push_back(config);
|
||||
}
|
||||
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
|
||||
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
|
||||
config.size_key, config.size_value, config.key_only, config.min_hits);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
|
||||
|
|
@ -788,9 +839,13 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
|
||||
GGML_ASSERT((seq_id < (llama_seq_id) config.size()));
|
||||
|
||||
if (is_other) {
|
||||
return;
|
||||
}
|
||||
|
||||
common_ngram_map_accept(config[seq_id], n_accepted);
|
||||
}
|
||||
|
||||
|
|
@ -812,7 +867,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
|||
// the last position in the prompt that was added to the ngram container
|
||||
size_t i_last = 0;
|
||||
|
||||
// length of the last drafted n‑gram (number of tokens returned by draft)
|
||||
// length of the last drafted n-gram (number of tokens returned by draft)
|
||||
size_t n_draft_last = 0;
|
||||
|
||||
// consecutive accept rounds with low acceptance fraction (< 0.5)
|
||||
|
|
@ -830,8 +885,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
|||
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
|
||||
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
|
||||
|
||||
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
|
||||
this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
|
||||
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
|
||||
this->params.n_match, this->params.n_max, this->params.n_min);
|
||||
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
|
||||
mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
|
||||
if (this->params.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
|
|
@ -921,7 +979,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
|||
}
|
||||
result.resize(result.size() - n);
|
||||
|
||||
// store length of drafted n‑gram for later acceptance analysis
|
||||
// store length of drafted n-gram for later acceptance analysis
|
||||
sinfo.n_draft_last = result.size();
|
||||
}
|
||||
|
||||
|
|
@ -943,17 +1001,21 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
|
||||
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
|
||||
if (is_other) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto & sinfo = sinfos[seq_id];
|
||||
|
||||
// compute acceptance fraction if we have a recorded draft length
|
||||
if (sinfo.n_draft_last > 0) {
|
||||
const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last;
|
||||
if (f_acc < 0.5) {
|
||||
if (f_acc < 0.25) {
|
||||
sinfo.n_low++;
|
||||
if (sinfo.n_low >= 3) {
|
||||
if (sinfo.n_low >= 5) {
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) – resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
|
|
@ -1003,6 +1065,12 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
|||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
|
||||
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
|
||||
n_draft,
|
||||
path_static.empty() ? "none" : path_static.c_str(),
|
||||
path_dynamic.empty() ? "none" : path_dynamic.c_str());
|
||||
|
||||
sinfos.resize(n_seq);
|
||||
|
||||
if (!path_static.empty()) {
|
||||
|
|
@ -1099,7 +1167,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
|||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
|
|
@ -1285,7 +1353,6 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
|||
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
|
||||
|
||||
for (const common_speculative_config & config : configs) {
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str());
|
||||
switch (config.type) {
|
||||
case COMMON_SPECULATIVE_TYPE_NONE:
|
||||
break;
|
||||
|
|
@ -1298,7 +1365,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
|||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
|
|
@ -1319,11 +1386,16 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
|||
impls.push_back(std::move(state));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
|
||||
impls.push_back(
|
||||
std::make_unique<common_speculative_impl_ngram_map_k>(
|
||||
get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
|
||||
impls.push_back(
|
||||
std::make_unique<common_speculative_impl_ngram_map_k>(
|
||||
config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
|
||||
get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
|
||||
|
|
@ -1515,11 +1587,6 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
|||
|
||||
GGML_ASSERT(impl);
|
||||
|
||||
// TODO: currently only the implementation that generated the draft is used to accept it
|
||||
// however, some implementations (such as MTP) need to also "see" the accepted tokens
|
||||
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
|
||||
// inform the implementation if the accepted tokens are from another implementation and
|
||||
// pass the accepted tokens to all remaining implementations using `is_other == true`
|
||||
{
|
||||
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
|
||||
if (n_accepted > 0) {
|
||||
|
|
@ -1527,9 +1594,16 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
|||
impl->n_acc_tokens += n_accepted;
|
||||
}
|
||||
|
||||
impl->accept(seq_id, n_accepted);
|
||||
impl->accept(seq_id, n_accepted, false);
|
||||
impl->n_call_accept++;
|
||||
}
|
||||
|
||||
// accept with the rest of the implementations, using is_other == true
|
||||
for (auto & impl_other : spec->impls) {
|
||||
if (impl_other.get() != impl) {
|
||||
impl_other->accept(seq_id, n_accepted, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
|
|
@ -1549,7 +1623,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
|||
str_perf = "";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
|
||||
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
|
|
|
|||
|
|
@ -115,15 +115,15 @@ def parse_args() -> argparse.Namespace:
|
|||
)
|
||||
parser.add_argument(
|
||||
"--mmproj", action="store_true",
|
||||
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
|
||||
help="Export multimodal projector (mmproj) for vision models. This will only work on some vision models. An 'mmproj-' prefix will be added to the output file name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mtp", action="store_true",
|
||||
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
|
||||
help="Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. An 'mtp-' prefix will be added to the output file name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-mtp", action="store_true",
|
||||
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
|
||||
help="Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, but even though the bundled default is more space-efficient overall, this allows differing quantization which may be more performant.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mistral-format", action="store_true",
|
||||
|
|
|
|||
|
|
@ -445,6 +445,11 @@ if __name__ == '__main__':
|
|||
if self.lazy:
|
||||
tensor = LazyTorchTensor.from_eager(tensor)
|
||||
base_name = get_base_tensor_name(name)
|
||||
# filter base name, ignore tensor transformations for now
|
||||
data_gen = lambda g=tensor: g # noqa: E731
|
||||
if (titem := self.filter_tensors((base_name, data_gen))) is None:
|
||||
continue
|
||||
base_name, _ = titem
|
||||
# note: mergekit-extract-lora also adds token embeddings to the adapter
|
||||
is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
|
||||
is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
|
||||
|
|
|
|||
|
|
@ -149,6 +149,8 @@ class TaskState:
|
|||
t_gen_ms: Optional[float] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
server_name: Optional[str] = None
|
||||
chunk_idx: int = 0
|
||||
problem_idx: int = 0
|
||||
|
||||
|
||||
class EvalState:
|
||||
|
|
@ -233,7 +235,9 @@ class EvalState:
|
|||
tps_gen: Optional[float] = None,
|
||||
t_gen_ms: Optional[float] = None,
|
||||
reasoning_content: Optional[str] = None,
|
||||
server_name: Optional[str] = None
|
||||
server_name: Optional[str] = None,
|
||||
chunk_idx: int = 0,
|
||||
problem_idx: int = 0,
|
||||
):
|
||||
with self._lock:
|
||||
if "cases" not in self.task_states:
|
||||
|
|
@ -252,7 +256,9 @@ class EvalState:
|
|||
"tps_gen": tps_gen,
|
||||
"t_gen_ms": t_gen_ms,
|
||||
"reasoning_content": reasoning_content,
|
||||
"server_name": server_name
|
||||
"server_name": server_name,
|
||||
"chunk_idx": chunk_idx,
|
||||
"problem_idx": problem_idx,
|
||||
}
|
||||
|
||||
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
|
||||
|
|
@ -289,6 +295,9 @@ class EvalState:
|
|||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
question_text, prompt, expected = self.get_case(i)
|
||||
# Extract chunk_idx from task_id for pending cases
|
||||
_parts = task_id.rsplit("_", 2)
|
||||
_chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
|
||||
if task_id in self.task_states.get("cases", {}):
|
||||
all_cases[task_id] = self.task_states["cases"][task_id]
|
||||
else:
|
||||
|
|
@ -306,7 +315,9 @@ class EvalState:
|
|||
"tps_gen": None,
|
||||
"t_gen_ms": None,
|
||||
"reasoning_content": None,
|
||||
"server_name": None
|
||||
"server_name": None,
|
||||
"chunk_idx": _chunk_idx,
|
||||
"problem_idx": i,
|
||||
}
|
||||
|
||||
ci_lower, ci_upper = self.accuracy_ci()
|
||||
|
|
@ -382,11 +393,12 @@ class EvalState:
|
|||
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
|
||||
escaped_server = self._escape_html(server_name)
|
||||
|
||||
answer_class = status_class if status == "ok" else ""
|
||||
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
|
||||
<td>{task_id}</td>
|
||||
<td class="{status_class}">{status_text}</td>
|
||||
<td>{self._escape_html(expected)}</td>
|
||||
<td>{self._escape_html(answer)}</td>
|
||||
<td class="{answer_class}">{self._escape_html(answer)}</td>
|
||||
<td>{tokens_str}</td>
|
||||
<td>{tps_str}</td>
|
||||
<td>{t_gen_str}</td>
|
||||
|
|
@ -405,6 +417,53 @@ class EvalState:
|
|||
|
||||
rows_html = "\n".join(rows)
|
||||
|
||||
# ---- per-problem summary table ----
|
||||
problem_groups: Dict[int, List[Dict[str, Any]]] = {}
|
||||
for _tid, _case in cases.items():
|
||||
if _case.get("status") != "ok":
|
||||
continue
|
||||
_pidx = _case.get("problem_idx")
|
||||
if _pidx is None:
|
||||
_p_parts = _tid.rsplit("_", 2)
|
||||
_pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0
|
||||
problem_groups.setdefault(_pidx, []).append(_case)
|
||||
|
||||
summary_rows_html = ""
|
||||
if problem_groups:
|
||||
def _stat(v, fmt=".1f", avg_fmt=None):
|
||||
if not v:
|
||||
return ("–", "–", "–")
|
||||
af = fmt if avg_fmt is None else avg_fmt
|
||||
return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}")
|
||||
|
||||
summary_data = []
|
||||
for pidx, g in problem_groups.items():
|
||||
runs = len(g)
|
||||
n_ok = sum(1 for c in g if c.get("correct", False))
|
||||
toks = [c["tokens"] for c in g if c.get("tokens") is not None]
|
||||
tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None]
|
||||
tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None]
|
||||
summary_data.append((
|
||||
pidx, runs, n_ok,
|
||||
_stat(toks, "d", ".0f"),
|
||||
_stat(tps),
|
||||
_stat(tg),
|
||||
))
|
||||
|
||||
summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending
|
||||
|
||||
summary_rows_html = "\n".join(
|
||||
f"""<tr class="summary-row">
|
||||
<td>{p:03d}</td>
|
||||
<td>{r}</td>
|
||||
<td>{n}/{r}</td>
|
||||
<td>{tk[0]}</td><td>{tk[1]}</td><td>{tk[2]}</td>
|
||||
<td>{tp[0]}</td><td>{tp[1]}</td><td>{tp[2]}</td>
|
||||
<td>{tg[0]}</td><td>{tg[1]}</td><td>{tg[2]}</td>
|
||||
</tr>"""
|
||||
for p, r, n, tk, tp, tg in summary_data
|
||||
)
|
||||
|
||||
html_content = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
|
|
@ -412,10 +471,10 @@ class EvalState:
|
|||
<title>{self.dataset_type.upper()} Eval</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui, sans-serif; margin: 0; padding: 16px; background: #fff; color: #222; }}
|
||||
.bar {{ padding: 8px 0; font-size: 14px; color: #555; }}
|
||||
.bar span {{ margin-right: 20px; }}
|
||||
.bar b {{ color: #222; }}
|
||||
table {{ width: 100%; border-collapse: collapse; font-size: 13px; }}
|
||||
.bar {{ padding: 8px 0; font-size: 13px; color: #555; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; display: grid; grid-template-columns: auto 1fr auto 1fr; gap: 2px 12px; align-items: baseline; }}
|
||||
.bar .label {{ color: #888; }}
|
||||
.bar .value {{ color: #222; }}
|
||||
table {{ width: 100%; border-collapse: collapse; font-size: 13px; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; }}
|
||||
th {{ text-align: left; padding: 6px 8px; border-bottom: 2px solid #ccc; font-weight: 600; }}
|
||||
td {{ padding: 4px 8px; border-bottom: 1px solid #eee; vertical-align: top; }}
|
||||
.task-row {{ cursor: pointer; }}
|
||||
|
|
@ -429,37 +488,88 @@ class EvalState:
|
|||
.details-content {{ padding: 8px 16px; background: #f6f8fa; font-size: 12px; }}
|
||||
.details-content b {{ color: #555; }}
|
||||
.details-content pre {{ background: #fff; border: 1px solid #e1e4e8; padding: 8px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; margin: 4px 0 8px; }}
|
||||
.summary-table {{ margin-bottom: 16px; font-size: 13px; width: 100%; }}
|
||||
.summary-row {{ background: #fafbfc; }}
|
||||
.summary-row:hover {{ background: #f5f5f5; }}
|
||||
.summary-table th {{ text-align: right; font-weight: 600; }}
|
||||
.summary-table th:first-child {{ text-align: left; }}
|
||||
.summary-table th[colspan] {{ text-align: center; }}
|
||||
.summary-table td {{ text-align: right; }}
|
||||
.summary-table td:first-child {{ text-align: left; }}
|
||||
.tabs {{ display: flex; border-bottom: 2px solid #ddd; margin: 12px 0 0; }}
|
||||
.tab-btn {{ padding: 6px 16px; border: none; background: none; font-size: 13px; cursor: pointer; color: #555; border-bottom: 2px solid transparent; margin-bottom: -2px; font-weight: 500; }}
|
||||
.tab-btn:hover {{ color: #222; }}
|
||||
.tab-btn.active {{ color: #222; border-bottom-color: #222; font-weight: 600; }}
|
||||
.tab-content {{ display: none; }}
|
||||
.tab-content.active {{ display: block; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="bar">
|
||||
<span><b>{self.dataset_type.upper()}</b></span>
|
||||
<span>Model: {self.model_name or 'N/A'}</span>
|
||||
<span>Accuracy: <b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</span>
|
||||
<span>Correct: <span class="correct">{n_correct}</span> / {len(completed)}</span>
|
||||
<span>Pending: {n_pending}</span>
|
||||
<span>Time: {self.total_time:.1f}s</span>
|
||||
<span>Sampling: {sampling_str}</span>
|
||||
<div class="label">Dataset</div><div class="value"><b>{self.dataset_type.upper()}</b></div>
|
||||
<div class="label">Model</div><div class="value"><b>{self.model_name or 'N/A'}</b></div>
|
||||
<div class="label">Accuracy</div><div class="value"><b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</div>
|
||||
<div class="label">Correct</div><div class="value"><span class="correct">{n_correct}</span> / {len(completed)}</div>
|
||||
<div class="label">Pending</div><div class="value">{n_pending}</div>
|
||||
<div class="label">Time</div><div class="value">{self.total_time:.1f}s</div>
|
||||
<div class="label">Sampling</div><div class="value">{sampling_str}</div>
|
||||
</div>
|
||||
<div class="tabs">
|
||||
<button class="tab-btn active" data-tab="detailed" onclick="switchTab(this)">Detailed</button>
|
||||
<button class="tab-btn" data-tab="summary" onclick="switchTab(this)">Summary</button>
|
||||
</div>
|
||||
<div id="tab-detailed" class="tab-content active">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th></th>
|
||||
<th>Gold</th>
|
||||
<th>Answer</th>
|
||||
<th>Tokens</th>
|
||||
<th>T/s</th>
|
||||
<th>Gen s</th>
|
||||
<th>Server</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div id="tab-summary" class="tab-content">
|
||||
<table class="summary-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Problem</th>
|
||||
<th>Runs</th>
|
||||
<th>Correct</th>
|
||||
<th colspan="3">Tokens</th>
|
||||
<th colspan="3">T/s</th>
|
||||
<th colspan="3">Gen s</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th></th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
<th>min</th><th>avg</th><th>max</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{summary_rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th></th>
|
||||
<th>Gold</th>
|
||||
<th>Answer</th>
|
||||
<th>Tokens</th>
|
||||
<th>T/s</th>
|
||||
<th>Gen s</th>
|
||||
<th>Server</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows_html}
|
||||
</tbody>
|
||||
</table>
|
||||
<script>
|
||||
function toggleDetails(id) {{ document.getElementById('details-'+id).classList.toggle('open'); }}
|
||||
function switchTab(btn) {{
|
||||
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
|
||||
document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
|
||||
btn.classList.add('active');
|
||||
document.getElementById('tab-'+btn.dataset.tab).classList.add('active');
|
||||
}}
|
||||
</script>
|
||||
</body>
|
||||
</html>"""
|
||||
|
|
@ -1062,12 +1172,19 @@ class Processor:
|
|||
) -> TaskState:
|
||||
question_text, prompt, expected = eval_state.get_case(i)
|
||||
|
||||
# Extract chunk_idx from task_id: "{dataset_type}_{chunk_idx:03d}_{index:03d}"
|
||||
_parts = task_id.rsplit("_", 2)
|
||||
chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
|
||||
problem_idx = i
|
||||
|
||||
task_state = TaskState(
|
||||
task_id=task_id,
|
||||
prompt=prompt,
|
||||
expected=expected,
|
||||
question_text=question_text,
|
||||
server_name=server_config.name
|
||||
server_name=server_config.name,
|
||||
chunk_idx=chunk_idx,
|
||||
problem_idx=problem_idx,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -1085,7 +1202,8 @@ class Processor:
|
|||
eval_state.add_result(
|
||||
task_id, prompt, expected, result, None,
|
||||
{"finish_reason": finish_reason}, False, task_state.status,
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
|
||||
chunk_idx, problem_idx,
|
||||
)
|
||||
eval_state.dump()
|
||||
return task_state
|
||||
|
|
@ -1108,7 +1226,8 @@ class Processor:
|
|||
eval_state.add_result(
|
||||
task_id, prompt, expected, result, answer,
|
||||
grader_log, is_correct, "ok",
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
|
||||
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
|
||||
chunk_idx, problem_idx,
|
||||
)
|
||||
|
||||
eval_state.dump()
|
||||
|
|
|
|||
|
|
@ -65,34 +65,70 @@ def normalize_number(s: str) -> Optional[int]:
|
|||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
def __init__(self, split: str = "train"):
|
||||
def __init__(self, split: str = "train", dataset_type: str = "aime"):
|
||||
self.split = split
|
||||
self.dataset_type = dataset_type
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
||||
def _load_dataset(self):
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
def _get_question_text(self, question: Dict) -> str:
|
||||
"""Get question text, handling different dataset field names."""
|
||||
return question.get("problem", question.get("question", ""))
|
||||
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
def _load_dataset(self):
|
||||
if self.dataset_type == "aime":
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
elif self.dataset_type == "aime2025":
|
||||
print(f"Loading AIME2025 dataset...")
|
||||
ds_list = []
|
||||
for config_name in ["AIME2025-I", "AIME2025-II"]:
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test")
|
||||
ds_list.extend(ds)
|
||||
ds = ds_list
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
|
||||
|
||||
self.questions = list(ds)
|
||||
print(f"AIME dataset loaded: {len(self.questions)} questions")
|
||||
print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")
|
||||
|
||||
def find_question(self, request_text: str) -> Optional[Dict]:
|
||||
# Strip common template prefixes to get the actual question text
|
||||
# Templates include things like "Solve the following math problem step by step..."
|
||||
# The actual question usually follows a blank line or after the template instruction
|
||||
cleaned = request_text
|
||||
# Split on double newline and take the part that looks like the problem
|
||||
parts = cleaned.split('\n\n')
|
||||
if len(parts) > 1:
|
||||
# Find the part that's longest (likely the actual problem text)
|
||||
problem_parts = [p for p in parts if len(p.strip()) > 100]
|
||||
if problem_parts:
|
||||
cleaned = max(problem_parts, key=lambda x: len(x))
|
||||
|
||||
best_match = None
|
||||
best_distance = -1
|
||||
best_index = -1
|
||||
|
||||
for i, question in enumerate(self.questions):
|
||||
question_text = question["problem"]
|
||||
request_lower = request_text.lower()
|
||||
question_text = self._get_question_text(question)
|
||||
request_lower = cleaned.lower()
|
||||
question_lower = question_text.lower()
|
||||
|
||||
# Check if question text is contained in the cleaned request
|
||||
if question_lower in request_lower or request_lower in question_lower:
|
||||
debug_log(f"DEBUG: Found substring match at index {i}")
|
||||
return question
|
||||
|
||||
# Exact match
|
||||
if question_lower == request_lower:
|
||||
debug_log(f"DEBUG: Found exact match at index {i}")
|
||||
|
|
@ -118,7 +154,7 @@ class AimeDataset:
|
|||
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
|
||||
return best_match
|
||||
|
||||
debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
|
||||
debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
|
||||
return None
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
|
|
@ -134,15 +170,16 @@ class Simulator:
|
|||
port: int = 8033,
|
||||
host: str = "localhost",
|
||||
success_rate: float = 0.8,
|
||||
dataset_split: str = "train"
|
||||
dataset_split: str = "train",
|
||||
dataset_type: str = "aime"
|
||||
):
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.success_rate = success_rate
|
||||
self.dataset = AimeDataset(dataset_split)
|
||||
self.dataset = AimeDataset(dataset_split, dataset_type)
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
id=dataset_type,
|
||||
tasks=[dataset_type],
|
||||
task_states={},
|
||||
sampling_config={"temperature": 0, "max_tokens": 2048}
|
||||
)
|
||||
|
|
@ -159,6 +196,10 @@ class Simulator:
|
|||
else:
|
||||
response_text = self._generate_wrong_answer(question)
|
||||
|
||||
comp_tokens = random.randint(10000, 60000)
|
||||
tps_gen = random.uniform(90.0, 110.0)
|
||||
t_gen_ms = comp_tokens / tps_gen * 1000
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
|
|
@ -176,8 +217,12 @@ class Simulator:
|
|||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150
|
||||
"completion_tokens": comp_tokens,
|
||||
"total_tokens": 100 + comp_tokens
|
||||
},
|
||||
"timings": {
|
||||
"predicted_ms": t_gen_ms,
|
||||
"predicted_per_second": tps_gen
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -218,6 +263,12 @@ class Simulator:
|
|||
return response
|
||||
|
||||
class RequestHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == "/v1/models":
|
||||
self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
|
||||
return
|
||||
self._send_json({"error": "Not found"}, 404)
|
||||
|
||||
def do_POST(self):
|
||||
if self.path != "/v1/chat/completions":
|
||||
self._send_json({"error": "Not found"}, 404)
|
||||
|
|
@ -280,6 +331,13 @@ def main():
|
|||
default=0.8,
|
||||
help="Success rate 0-1 (default: 0.8)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="aime",
|
||||
choices=["aime", "aime2025"],
|
||||
help="Dataset type (default: aime)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
|
|
@ -294,7 +352,8 @@ def main():
|
|||
port=args.port,
|
||||
host=args.host,
|
||||
success_rate=args.success_rate,
|
||||
dataset_split=args.dataset_split
|
||||
dataset_split=args.dataset_split,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
|
||||
server = HTTPServer((args.host, args.port), RequestHandler)
|
||||
|
|
@ -304,7 +363,7 @@ def main():
|
|||
print("\n=== llama-server-simulator ===")
|
||||
print(f"Server running on http://{args.host}:{args.port}")
|
||||
print(f"Success rate: {args.success_rate}")
|
||||
print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
|
||||
print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
|
||||
print("\nPress Ctrl+C to stop\n")
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
|
|||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return 8;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 2;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return 8;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
|
|||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
|
||||
// note: this is slower
|
||||
//const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0;
|
||||
const bool is_c4 = false;
|
||||
|
||||
snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
|
|
@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
|
|||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nth = MIN(args.ne00, nth_max);
|
||||
|
||||
const int nk0 = (args.ne00 + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
|
@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
nk0 = ne00/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
int nth = std::min<int>(nk0*ne01, 256);
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
|
@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (nrptg*nth > 256) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
|
|
@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
|
||||
|
||||
const int nth = std::min(1024, ne0);
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
const int nth = MIN(args.ne0, nth_max);
|
||||
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl(
|
|||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1u) {
|
||||
if (K > 1) {
|
||||
const int target_slot = (int)t - shift;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
|
|
@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl(
|
|||
}
|
||||
}
|
||||
|
||||
if (K == 1u) {
|
||||
if (K == 1) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
|
|
@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32(
|
|||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
|
|
@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32(
|
|||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
|
|
@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_pad_f32(
|
||||
template <typename T>
|
||||
kernel void kernel_pad_impl(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t k0 = tgpig.x/args.ne1;
|
||||
const int32_t i1 = tgpig.x - k0*args.ne1;
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
const int32_t i03 = i3;
|
||||
const int32_t i02 = i2;
|
||||
const int32_t i01 = i1;
|
||||
|
||||
const int64_t i03 = i3;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i01 = i1;
|
||||
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
if (i0 < args.ne00) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
||||
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
||||
if (i0 >= args.ne0) {
|
||||
break;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
||||
|
||||
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
||||
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
||||
|
||||
// TODO: this is slow - optimize
|
||||
kernel void kernel_pad_reflect_1d_f32(
|
||||
constant ggml_metal_kargs_pad_reflect_1d & args,
|
||||
device const char * src0,
|
||||
|
|
@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t(
|
|||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
||||
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
break;
|
||||
|
|
@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q(
|
|||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
|
||||
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
||||
|
||||
quantize_func(src, dst_data[i00]);
|
||||
|
|
@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32(
|
|||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
|
||||
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
T4x4 temp;
|
||||
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
||||
dst_data[i00] = temp;
|
||||
|
|
|
|||
|
|
@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|||
return &guid;
|
||||
}
|
||||
|
||||
struct ggml_backend_rpc_device_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
uint64_t last_graph_uid;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_buffer_type_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
|
|
@ -211,7 +219,6 @@ struct ggml_backend_rpc_context {
|
|||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
uint64_t last_graph_uid;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_buffer_context {
|
||||
|
|
@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
|
|||
|
||||
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend);
|
||||
ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context;
|
||||
|
||||
GGML_ASSERT(cgraph->n_nodes > 0);
|
||||
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
|
||||
bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid;
|
||||
if (reuse) {
|
||||
rpc_msg_graph_recompute_req request;
|
||||
request.device = rpc_ctx->device;
|
||||
|
|
@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
} else {
|
||||
rpc_ctx->last_graph_uid = cgraph->uid;
|
||||
rpc_dev_ctx->last_graph_uid = cgraph->uid;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
|
|
@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
|||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name,
|
||||
/* .last_graph_uid = */ 0,
|
||||
};
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
|
|
@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||
}
|
||||
}
|
||||
|
||||
// device interface
|
||||
|
||||
struct ggml_backend_rpc_device_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
||||
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
|
||||
|
|
@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
|
|||
std::string dev_name = "RPC" + std::to_string(dev_id);
|
||||
std::string dev_desc = std::string(endpoint);
|
||||
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ ind,
|
||||
/* .name = */ dev_name,
|
||||
/* .description = */ dev_desc
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ ind,
|
||||
/* .name = */ dev_name,
|
||||
/* .description = */ dev_desc,
|
||||
/* .last_graph_uid = */ 0,
|
||||
};
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
|
|
|
|||
|
|
@ -581,7 +581,8 @@ struct llm_graph_params {
|
|||
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
|
||||
(
|
||||
(!ubatch.token && !other.ubatch.token) ||
|
||||
(!ubatch.embd && !other.ubatch.embd)
|
||||
(!ubatch.embd && !other.ubatch.embd) ||
|
||||
(ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd)
|
||||
);
|
||||
|
||||
// when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same
|
||||
|
|
|
|||
|
|
@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
if (mem_recr->n_rs_seq > 0) {
|
||||
// [TAG_RECURRENT_ROLLBACK_SPLITS]
|
||||
// TODO: recurrent state rollback does not support equal splits
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
}
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
if (mem_recr->n_rs_seq > 0) {
|
||||
// [TAG_RECURRENT_ROLLBACK_SPLITS]
|
||||
// TODO: recurrent state rollback does not support equal splits
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
}
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -416,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
if (n_rs_seq > 0) {
|
||||
// [TAG_RECURRENT_ROLLBACK_SPLITS]
|
||||
// TODO: recurrent state rollback does not support equal splits
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ public:
|
|||
|
||||
// number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
|
||||
uint32_t n_rs_seq = 0;
|
||||
|
||||
// per-seq rollback index
|
||||
std::vector<uint32_t> rs_idx;
|
||||
|
||||
|
|
|
|||
|
|
@ -447,13 +447,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
return build_delta_net_chunking(q, k, v, g, b, s, il);
|
||||
}
|
||||
|
||||
bool llm_build_delta_net_base::keep_rs() const {
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
return cparams.n_rs_seq > 0
|
||||
&& n_seq_tokens > 1
|
||||
&& (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_delta_net_base::build_conv_state(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * conv_states_all,
|
||||
|
|
@ -461,12 +454,12 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
|
|||
int64_t conv_kernel_size,
|
||||
int64_t conv_channels,
|
||||
int il) {
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
const uint32_t mem_size = mctx_cur->get_size();
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const bool keep = keep_rs();
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
const auto mem_size = mctx_cur->get_size();
|
||||
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
|
|
@ -480,32 +473,52 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
|
|||
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
|
||||
if (!keep) {
|
||||
ggml_tensor * last_conv_states =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
|
||||
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
|
||||
cb(last_conv_states, "last_conv_states", il);
|
||||
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
|
||||
|
||||
ggml_tensor * state_update_target =
|
||||
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
|
||||
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
|
||||
cb(state_update_target, "state_update_target", il);
|
||||
const size_t row_size = ggml_row_size(conv_states_all->type, row_count);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
if (cparams.n_rs_seq == 0) {
|
||||
const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0];
|
||||
const int64_t s_slot = 0;
|
||||
|
||||
ggml_tensor * conv_state_last =
|
||||
ggml_view_3d(ctx0, conv_input,
|
||||
conv_kernel_size - 1, conv_channels, n_seqs,
|
||||
conv_input->nb[1], conv_input->nb[2],
|
||||
ggml_row_size(conv_input->type, s_idx));
|
||||
cb(conv_state_last, "conv_state_last", il);
|
||||
|
||||
ggml_tensor * conv_state_update =
|
||||
ggml_view_2d(ctx0, conv_states_all,
|
||||
row_count, n_seqs, conv_states_all->nb[1],
|
||||
(s_slot * mem_size + kv_head) * row_size);
|
||||
cb(conv_state_update, "conv_state_update", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update));
|
||||
} else {
|
||||
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
|
||||
const size_t row_size = row_count * ggml_element_size(conv_states_all);
|
||||
for (int64_t t = 1; t <= n_seq_tokens; ++t) {
|
||||
const uint32_t slot = (uint32_t)(n_seq_tokens - t);
|
||||
ggml_tensor * src =
|
||||
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
|
||||
conv_input->nb[1], conv_input->nb[2],
|
||||
t * ggml_element_size(conv_input));
|
||||
ggml_tensor * dst =
|
||||
ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
|
||||
conv_states_all->nb[1],
|
||||
((size_t) slot * mem_size + kv_head) * row_size);
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
// [TAG_RECURRENT_ROLLBACK_SPLITS]
|
||||
// TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are
|
||||
// inside the same ubatch. currently with `split_equal()` this is not correct
|
||||
|
||||
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
|
||||
|
||||
for (int64_t t = 1; t <= K; ++t) {
|
||||
const int64_t s_idx = std::max<int64_t>(0, conv_input->ne[0] - conv_states->ne[0] - K + t);
|
||||
const int64_t s_slot = K - t;
|
||||
|
||||
ggml_tensor * conv_state_last =
|
||||
ggml_view_3d(ctx0, conv_input,
|
||||
conv_kernel_size - 1, conv_channels, n_seqs,
|
||||
conv_input->nb[1], conv_input->nb[2],
|
||||
ggml_row_size(conv_input->type, s_idx));
|
||||
|
||||
ggml_tensor * conv_state_update =
|
||||
ggml_view_2d(ctx0,
|
||||
conv_states_all, row_count, n_seqs,
|
||||
conv_states_all->nb[1],
|
||||
(s_slot * mem_size + kv_head) * row_size);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -531,7 +544,9 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
|||
const int64_t n_seqs = s->ne[3];
|
||||
const int64_t n_seq_tokens = q->ne[2];
|
||||
|
||||
if (!keep_rs()) {
|
||||
const bool keep = cparams.n_rs_seq > 0;
|
||||
|
||||
if (!keep) {
|
||||
auto attn_out = build_delta_net(q, k, v, g, b, s, il);
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
|
|
@ -547,14 +562,18 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
|||
}
|
||||
|
||||
const int64_t D = S_v * S_v * H_v;
|
||||
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
|
||||
const int64_t K = cparams.n_rs_seq + 1;
|
||||
|
||||
// TODO: remove pad + simplify
|
||||
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
|
||||
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
|
||||
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
|
||||
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
|
||||
|
||||
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
|
||||
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
|
||||
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
|
||||
if (n_seq_tokens > 1) {
|
||||
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
|
||||
} else {
|
||||
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il);
|
||||
}
|
||||
|
||||
const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
|
||||
const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
|
||||
|
|
@ -576,9 +595,11 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
|||
ggml_row_size(gdn_out->type, S_v * S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
|
||||
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
|
||||
|
||||
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
|
||||
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
((size_t) cache_slot * mem_size + kv_head) * row_size);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -66,9 +66,6 @@ struct llm_build_delta_net_base : public llm_graph_context {
|
|||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// true when speculative rollback is enabled and the batch fits in the rs cache
|
||||
bool keep_rs() const;
|
||||
|
||||
// read conv state from cache, concat with qkv_mixed, write back (single slot or per-token)
|
||||
// qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
|
||||
ggml_tensor * build_conv_state(
|
||||
|
|
|
|||
|
|
@ -496,7 +496,8 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
|||
const int64_t n_embd_head = hparams.n_embd_head_v();
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
|
||||
|
||||
// The MTP block lives at the source file's original layer index.
|
||||
// hparams.n_layer includes both main model layers and MTP layers. The MTP
|
||||
// layer is stored immediately after the main layers in model.layers[].
|
||||
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ int main(int argc, char ** argv) {
|
|||
if (!params.fit_params_print) {
|
||||
const common_params_fit_status status = common_fit_params(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
if (status != COMMON_PARAMS_FIT_STATUS_SUCCESS) {
|
||||
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
|
||||
exit(1);
|
||||
|
|
|
|||
|
|
@ -467,20 +467,26 @@ struct server_slot {
|
|||
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
|
||||
|
||||
SLT_INF(*this,
|
||||
"\n"
|
||||
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
|
||||
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
|
||||
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second);
|
||||
|
||||
SLT_INF(*this,
|
||||
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
t_token_generation, n_decoded, t_gen, n_gen_second);
|
||||
|
||||
SLT_INF(*this,
|
||||
" total time = %10.2f ms / %5d tokens\n",
|
||||
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
|
||||
t_token_generation, n_decoded, t_gen, n_gen_second,
|
||||
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
|
||||
|
||||
SLT_INF(*this,
|
||||
" graphs reused = %10d\n",
|
||||
llama_perf_context(ctx_tgt).n_reused);
|
||||
|
||||
if (n_draft_total > 0) {
|
||||
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
|
||||
SLT_CNT(*this,
|
||||
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total
|
||||
);
|
||||
SLT_INF(*this,
|
||||
"draft acceptance = %0.5f (%5d accepted / %5d generated)\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total);
|
||||
}
|
||||
|
||||
common_speculative_print_stats(spec);
|
||||
|
|
@ -2583,9 +2589,9 @@ private:
|
|||
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa - 1);
|
||||
|
||||
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
if (n_past > 0 && n_past <= slot.prompt.n_tokens()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
|
|
@ -3885,6 +3891,7 @@ void server_routes::init_routes() {
|
|||
{ "eos_token", meta->eos_token_str },
|
||||
{ "build_info", meta->build_info },
|
||||
{ "is_sleeping", queue_tasks.is_sleeping() },
|
||||
{ "cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy },
|
||||
};
|
||||
if (params.use_jinja) {
|
||||
if (!tmpl_tools.empty()) {
|
||||
|
|
|
|||
|
|
@ -1165,6 +1165,7 @@ void server_models_routes::init_routes() {
|
|||
// Deprecated: use ui_settings instead (kept for backward compat)
|
||||
{"webui_settings", webui_settings},
|
||||
{"build_info", std::string(llama_build_info())},
|
||||
{"cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy},
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
|
|
|||
2
tools/ui/.env.example
Normal file
2
tools/ui/.env.example
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
VITE_PUBLIC_APP_NAME='llama-ui'
|
||||
# VITE_DEBUG='true'
|
||||
2
tools/ui/.gitignore
vendored
2
tools/ui/.gitignore
vendored
|
|
@ -25,4 +25,4 @@ vite.config.ts.timestamp-*
|
|||
|
||||
*storybook.log
|
||||
storybook-static
|
||||
*.code-workspace
|
||||
*.code-workspace
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ export default ts.config(
|
|||
prettier,
|
||||
...svelte.configs.prettier,
|
||||
{
|
||||
languageOptions: {
|
||||
globals: { ...globals.browser, ...globals.node }
|
||||
},
|
||||
languageOptions: { globals: { ...globals.browser, ...globals.node } },
|
||||
rules: {
|
||||
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
|
||||
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
|
||||
|
|
@ -30,6 +28,7 @@ export default ts.config(
|
|||
'svelte/no-at-html-tags': 'off',
|
||||
// This app uses hash-based routing (#/) where resolve() from $app/paths does not apply
|
||||
'svelte/no-navigation-without-resolve': 'off',
|
||||
|
||||
// Enforce empty line at end of file
|
||||
'eol-last': 'error'
|
||||
}
|
||||
|
|
|
|||
60
tools/ui/package-lock.json
generated
60
tools/ui/package-lock.json
generated
|
|
@ -2307,9 +2307,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@sveltejs/kit": {
|
||||
"version": "2.59.1",
|
||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.59.1.tgz",
|
||||
"integrity": "sha512-d8OON70AphLdDesuTIl//M2O6fRTIicX8aYv8vhCiYEhTTI2OboKqey0Hu1A4VFhqwgqtq0vKDmPFGkw8kKmgw==",
|
||||
"version": "2.60.1",
|
||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.60.1.tgz",
|
||||
"integrity": "sha512-mQjlkNo+rJvpln7V2IGY2j99BqhcFbS4UN0AQNKNYfhBAFZTuCDAdW3a1sgf330mvtNvsBXn3HpAhcmvdJTcIQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -2318,7 +2318,7 @@
|
|||
"@types/cookie": "^0.6.0",
|
||||
"acorn": "^8.14.1",
|
||||
"cookie": "^0.6.0",
|
||||
"devalue": "^5.6.4",
|
||||
"devalue": "^5.8.1",
|
||||
"esm-env": "^1.2.2",
|
||||
"kleur": "^4.1.5",
|
||||
"magic-string": "^0.30.5",
|
||||
|
|
@ -4296,9 +4296,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/devalue": {
|
||||
"version": "5.6.4",
|
||||
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.6.4.tgz",
|
||||
"integrity": "sha512-Gp6rDldRsFh/7XuouDbxMH3Mx8GMCcgzIb1pDTvNyn8pZGQ22u+Wa+lGV9dQCltFQ7uVw0MhRyb8XDskNFOReA==",
|
||||
"version": "5.8.1",
|
||||
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.8.1.tgz",
|
||||
"integrity": "sha512-4CXDYRBGqN+57wVJkuXBYmpAVUSg3L6JAQa/DFqm238G73E1wuyc/JhGQJzN7vUf/CMphYau2zXbfWzDR5aTEw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/devlop": {
|
||||
|
|
@ -4856,12 +4856,12 @@
|
|||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "8.5.0",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.0.tgz",
|
||||
"integrity": "sha512-XKhFohWaSBdVJNTi5TaHziqnPkv04I9UQV6q1Wy7Ui6GGQZVW12ojDFwqer14EvCXxjvPG0CyWXx7cAXpALB4Q==",
|
||||
"version": "8.5.2",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.2.tgz",
|
||||
"integrity": "sha512-5Kb34ipNX694DH48vN9irak1Qx30nb0PLYHXfJgw4YEjiC3ZEmZJhwOp+VfiCYwFzvFTdB9QkArYS5kXa2cx2A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.1.0"
|
||||
"ip-address": "^10.2.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
|
|
@ -4909,9 +4909,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/fast-uri": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz",
|
||||
"integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==",
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.2.tgz",
|
||||
"integrity": "sha512-rVjf7ArG3LTk+FS6Yw81V1DLuZl1bRbNrev6Tmd/9RaroeeRRJhAt7jg/6YFxbvAQXUCavSoZhPPj6oOx+5KjQ==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "github",
|
||||
|
|
@ -5541,9 +5541,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.14",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.14.tgz",
|
||||
"integrity": "sha512-am5zfg3yu6sqn5yjKBNqhnTX7Cv+m00ox+7jbaKkrLMRJ4rAdldd1xPd/JzbBWspqaQv6RSTrgFN95EsfhC+7w==",
|
||||
"version": "4.12.19",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.19.tgz",
|
||||
"integrity": "sha512-xa3eYXYXx68XTT4hZ7dRzsXBhaq85ToSrlUJNoR0gwz/1Ap/CNwX47wfvV7pc/xWhjKVVkLT7zBJy8chhNguqQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
|
@ -5722,9 +5722,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.1.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
|
||||
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
|
||||
"version": "10.2.0",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz",
|
||||
"integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
|
|
@ -6008,9 +6008,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/katex": {
|
||||
"version": "0.16.22",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.22.tgz",
|
||||
"integrity": "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg==",
|
||||
"version": "0.16.47",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.47.tgz",
|
||||
"integrity": "sha512-Eeo8Ys1doU1z+x8AZsPpQu+p/QcZBI5PeOo7QGQdy2x2m0MU/hYagBbGOmXwr5KVbEfVuWv9LpnQWeehogurjg==",
|
||||
"dev": true,
|
||||
"funding": [
|
||||
"https://opencollective.com/katex",
|
||||
|
|
@ -9245,9 +9245,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/svelte": {
|
||||
"version": "5.55.1",
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.1.tgz",
|
||||
"integrity": "sha512-QjvU7EFemf6mRzdMGlAFttMWtAAVXrax61SZYHdkD6yoVGQ89VeyKfZD4H1JrV1WLmJBxWhFch9H6ig/87VGjw==",
|
||||
"version": "5.55.7",
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.7.tgz",
|
||||
"integrity": "sha512-ymI5ykLPwIHW839E053FQbI1G+jnRFJEw3Kv5Y4njixVWywQBx+NUFpkkKyk5LIb36Fg9DVXSYpqiGekLD0hyw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
|
|
@ -9259,7 +9259,7 @@
|
|||
"aria-query": "5.3.1",
|
||||
"axobject-query": "^4.1.0",
|
||||
"clsx": "^2.1.1",
|
||||
"devalue": "^5.6.4",
|
||||
"devalue": "^5.8.1",
|
||||
"esm-env": "^1.2.1",
|
||||
"esrap": "^2.2.4",
|
||||
"is-reference": "^3.0.3",
|
||||
|
|
@ -10606,9 +10606,9 @@
|
|||
"license": "ISC"
|
||||
},
|
||||
"node_modules/ws": {
|
||||
"version": "8.18.3",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz",
|
||||
"integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==",
|
||||
"version": "8.20.1",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.20.1.tgz",
|
||||
"integrity": "sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
@import 'tailwindcss';
|
||||
@source ".";
|
||||
|
||||
@source '.';
|
||||
@plugin '@tailwindcss/forms';
|
||||
@plugin '@tailwindcss/typography';
|
||||
@import 'tw-animate-css';
|
||||
|
||||
@custom-variant dark (&:is(.dark *));
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
{#if videoSrc}
|
||||
<video controls class="mb-4 w-full" src={videoSrc}>
|
||||
<track kind="captions" src="" />
|
||||
Your browser does not support the video element.
|
||||
</video>
|
||||
{:else}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import { activeMessages } from '$lib/stores/conversations.svelte';
|
||||
|
||||
interface Props {
|
||||
currentModel?: string;
|
||||
disabled?: boolean;
|
||||
forceForegroundText?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
|
|
@ -20,7 +19,6 @@
|
|||
}
|
||||
|
||||
let {
|
||||
currentModel,
|
||||
disabled = false,
|
||||
forceForegroundText = false,
|
||||
hasAudioModality = $bindable(false),
|
||||
|
|
@ -41,14 +39,28 @@
|
|||
|
||||
let lastSyncedConversationModel: string | null = null;
|
||||
|
||||
let selectorModel = $derived(conversationModel ?? modelsStore.selectedModelName ?? null);
|
||||
|
||||
$effect(() => {
|
||||
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
|
||||
lastSyncedConversationModel = conversationModel;
|
||||
if (modelOptions().some((m) => m.model === conversationModel)) {
|
||||
modelsStore.selectedModelName = conversationModel;
|
||||
modelsStore.selectModelByName(conversationModel);
|
||||
} else {
|
||||
modelsStore.selectedModelName = null;
|
||||
modelsStore.clearSelection();
|
||||
}
|
||||
|
||||
modelsStore.selectModelByName(conversationModel);
|
||||
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
|
||||
lastSyncedConversationModel = conversationModel;
|
||||
} else if (
|
||||
isRouter &&
|
||||
!modelsStore.selectedModelId &&
|
||||
modelsStore.loadedModelIds.length > 0 &&
|
||||
activeMessages().length > 0 &&
|
||||
!conversationModel
|
||||
) {
|
||||
lastSyncedConversationModel = null;
|
||||
// auto-select the first loaded model only when nothing is selected yet
|
||||
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
|
||||
if (first) modelsStore.selectModelById(first.id);
|
||||
|
|
@ -151,7 +163,7 @@
|
|||
<ModelsSelectorSheet
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
{currentModel}
|
||||
currentModel={selectorModel}
|
||||
{forceForegroundText}
|
||||
{useGlobalSelection}
|
||||
/>
|
||||
|
|
@ -159,7 +171,7 @@
|
|||
<ModelsSelectorDropdown
|
||||
disabled={disabled || isOffline}
|
||||
bind:this={selectorModelRef}
|
||||
{currentModel}
|
||||
currentModel={selectorModel}
|
||||
{forceForegroundText}
|
||||
{useGlobalSelection}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@
|
|||
return;
|
||||
}
|
||||
|
||||
if (import.meta.env.DEV) {
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log('[ChatFormPickerMcpPrompts] Fetching completions for:', {
|
||||
serverName: selectedPrompt.serverName,
|
||||
promptName: selectedPrompt.name,
|
||||
|
|
@ -181,7 +181,7 @@
|
|||
value
|
||||
);
|
||||
|
||||
if (import.meta.env.DEV) {
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log('[ChatFormPickerMcpPrompts] Autocomplete result:', {
|
||||
argName,
|
||||
value,
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
<script lang="ts">
|
||||
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import {
|
||||
ChatScreenForm,
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenActionScrollDown,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError,
|
||||
DialogChatError,
|
||||
ServerLoadingSplash,
|
||||
DialogConfirmation
|
||||
DialogConfirmation,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import { setProcessingInfoContext } from '$lib/contexts';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
|
|
@ -34,11 +34,12 @@
|
|||
activeConversation
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { serverLoading, serverError, serverStore, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { onMount } from 'svelte';
|
||||
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
|
||||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
|
|
@ -67,6 +68,8 @@
|
|||
|
||||
let showEmptyFileDialog = $state(false);
|
||||
|
||||
let processingInfoVisible = $state(false);
|
||||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
|
@ -174,6 +177,10 @@
|
|||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
function handleProcessingInfoVisibility(visible: boolean) {
|
||||
processingInfoVisible = visible;
|
||||
}
|
||||
|
||||
function handleDragEnter(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
|
|
@ -338,7 +345,9 @@
|
|||
});
|
||||
|
||||
function handleMessagesReady() {
|
||||
if (!disableAutoScroll && !autoScroll.userScrolledUp) {
|
||||
if (disableAutoScroll) return;
|
||||
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
requestAnimationFrame(() => {
|
||||
autoScroll.scrollToBottom('instant');
|
||||
});
|
||||
|
|
@ -392,59 +401,32 @@
|
|||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onMessagesReady={handleMessagesReady}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}}
|
||||
onMessagesReady={handleMessagesReady}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class="pointer-events-none {isEmpty
|
||||
? 'absolute bottom-[calc(50dvh-7rem)]'
|
||||
: 'sticky bottom-4'} right-4 left-4 mt-auto pt-16 transition-all duration-200"
|
||||
class={[
|
||||
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
|
||||
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
|
||||
]}
|
||||
>
|
||||
{#if isEmpty}
|
||||
<div class="mb-8 px-4 text-center" use:fadeInView={{ duration: 300 }}>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
|
||||
<p class="text-muted-foreground md:text-lg">
|
||||
{serverStore.props?.modalities?.audio
|
||||
? 'Record audio, type a message '
|
||||
: 'Type a message'} or upload files to get started
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
<ChatScreenActionScrollDown
|
||||
container={chatScrollContainer}
|
||||
hasProcessingInfoVisible={processingInfoVisible}
|
||||
/>
|
||||
|
||||
{#if page.params.id}
|
||||
<ChatScreenProcessingInfo />
|
||||
{/if}
|
||||
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
|
||||
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
use:fadeInView={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
</div>
|
||||
{/if}
|
||||
<ChatScreenServerError />
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
<script lang="ts">
|
||||
import { ArrowDown } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
container: HTMLDivElement | undefined;
|
||||
hasProcessingInfoVisible: boolean;
|
||||
}
|
||||
|
||||
let { container, hasProcessingInfoVisible }: Props = $props();
|
||||
|
||||
let show = $state(false);
|
||||
|
||||
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
|
||||
|
||||
function checkVisibility() {
|
||||
if (!container) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
show = distanceFromBottom > clientHeight * 0.5;
|
||||
}
|
||||
|
||||
function scrollToBottom() {
|
||||
if (container) {
|
||||
container.scrollTo({
|
||||
top: container.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const c = container;
|
||||
if (c) {
|
||||
c.addEventListener('scroll', checkVisibility);
|
||||
checkVisibility();
|
||||
return () => {
|
||||
c.removeEventListener('scroll', checkVisibility);
|
||||
};
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="pointer-events-{show
|
||||
? 'auto'
|
||||
: 'none'} relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center"
|
||||
>
|
||||
<Button
|
||||
onclick={scrollToBottom}
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
class="pointer-events-all absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
|
||||
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
|
||||
? 1
|
||||
: 0};"
|
||||
aria-label="Scroll to bottom"
|
||||
>
|
||||
<ArrowDown class="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
isEmpty: boolean;
|
||||
}
|
||||
|
||||
let { isEmpty = false }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none mb-4 hidden px-4 text-center',
|
||||
isEmpty && 'pointer-events-auto block!'
|
||||
]}
|
||||
use:fadeInView={{ duration: 300 }}
|
||||
>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
|
||||
<p class="text-muted-foreground md:text-lg">
|
||||
{serverStore.props?.modalities?.audio ? 'Record audio, type a message ' : 'Type a message'} or upload
|
||||
files to get started
|
||||
</p>
|
||||
</div>
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { getProcessingInfoContext } from '$lib/contexts';
|
||||
import { page } from '$app/state';
|
||||
|
||||
const processingState = useProcessingState();
|
||||
const processingInfoCtx = getProcessingInfoContext();
|
||||
|
|
@ -16,6 +17,14 @@
|
|||
let isStreaming = $derived(isChatStreaming());
|
||||
let processingDetails = $derived(processingState.getTechnicalDetails());
|
||||
|
||||
let processingVisible = $derived(processingDetails.length > 0);
|
||||
|
||||
let { onVisibilityChange }: { onVisibilityChange?: (visible: boolean) => void } = $props();
|
||||
|
||||
$effect(() => {
|
||||
onVisibilityChange?.(processingVisible);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
const conversation = activeConversation();
|
||||
|
||||
|
|
@ -60,9 +69,12 @@
|
|||
</script>
|
||||
|
||||
<div
|
||||
class={['chat-processing-info-container pointer-events-none', showProcessingInfo && 'visible']}
|
||||
class={[
|
||||
'chat-processing-info-container pointer-events-none relative',
|
||||
page.params.id && showProcessingInfo && 'visible'
|
||||
]}
|
||||
>
|
||||
<div class="chat-processing-info-content">
|
||||
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
|
||||
{#each processingDetails as detail (detail)}
|
||||
<span class="chat-processing-info-detail pointer-events-auto backdrop-blur-sm">{detail}</span>
|
||||
{/each}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
<script lang="ts">
|
||||
import { AlertTriangle, RefreshCw } from '@lucide/svelte';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import { serverError, serverLoading, serverStore } from '$lib/stores/server.svelte';
|
||||
|
||||
let hasError = $derived(!!serverError());
|
||||
</script>
|
||||
|
||||
{#if hasError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
use:fadeInView={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={serverLoading()}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {serverLoading() ? 'animate-spin' : ''}" />
|
||||
{serverLoading() ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
</div>
|
||||
{/if}
|
||||
|
|
@ -667,3 +667,17 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
|
|||
* Only visible when `isCurrentConversationLoading` is true.
|
||||
*/
|
||||
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
|
||||
|
||||
/**
|
||||
* Scroll-to-bottom action button. Displays a floating button when the user
|
||||
* has scrolled up more than half a viewport height from the bottom.
|
||||
* Takes the chat container element as a prop to manage scroll state internally.
|
||||
*/
|
||||
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
|
||||
|
||||
/**
|
||||
* Server error alert displayed when the server is unreachable.
|
||||
* Shows the error message with a retry button.
|
||||
* Rendered inside ChatScreen when `serverError` store has a value.
|
||||
*/
|
||||
export { default as ChatScreenServerError } from './ChatScreen/ChatScreenServerError.svelte';
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@
|
|||
SETTINGS_KEYS
|
||||
} from '$lib/constants';
|
||||
import { ColorMode, UrlProtocol } from '$lib/enums';
|
||||
import { FileTypeText } from '$lib/enums/files';
|
||||
import { FileTypeText } from '$lib/enums/files.enums';
|
||||
import { highlightCode, detectIncompleteCodeBlock, type IncompleteCodeBlock } from '$lib/utils';
|
||||
import '$styles/katex-custom.scss';
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
} from '$lib/constants';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import { ColorMode } from '$lib/enums/ui';
|
||||
import { ColorMode } from '$lib/enums/ui.enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
import * as Select from '$lib/components/ui/select';
|
||||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_INFO, SETTINGS_KEYS } from '$lib/constants';
|
||||
import { SettingsFieldType } from '$lib/enums/settings';
|
||||
import { SettingsFieldType } from '$lib/enums/settings.enums';
|
||||
import { settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, selectedModelName, propsCacheVersion } from '$lib/stores/models.svelte';
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { Zap, Globe, Radio } from '@lucide/svelte';
|
|||
import { MCPTransportType } from '$lib/enums';
|
||||
import type { ClientCapabilities, Implementation } from '$lib/types';
|
||||
import type { Component } from 'svelte';
|
||||
import { MimeTypeImage } from '$lib/enums/files';
|
||||
import { MimeTypeImage } from '$lib/enums/files.enums';
|
||||
|
||||
export const DEFAULT_CLIENT_VERSION = '1.0.0';
|
||||
export const MCP_CLIENT_NAME = 'llama-ui-mcp';
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { ColorMode } from '$lib/enums/ui';
|
||||
import { SettingsFieldType } from '$lib/enums/settings';
|
||||
import { ColorMode } from '$lib/enums/ui.enums';
|
||||
import { SettingsFieldType } from '$lib/enums/settings.enums';
|
||||
import { SyncableParameterType } from '$lib/enums';
|
||||
import {
|
||||
Funnel,
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import {
|
|||
MimeTypeApplication,
|
||||
MimeTypeText
|
||||
} from '$lib/enums';
|
||||
import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files';
|
||||
import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files.enums';
|
||||
|
||||
// File type configuration using enums
|
||||
export const AUDIO_FILE_TYPES = {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { ToolSource } from '$lib/enums/tools';
|
||||
import { ToolSource } from '$lib/enums/tools.enums';
|
||||
|
||||
export const TOOL_GROUP_LABELS = {
|
||||
[ToolSource.BUILTIN]: 'Built-in',
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ export {
|
|||
AttachmentItemEnabledWhen,
|
||||
AttachmentAction,
|
||||
AttachmentItemVisibleWhen
|
||||
} from './attachment';
|
||||
} from './attachment.enums';
|
||||
|
||||
export { AgenticSectionType, ToolCallType } from './agentic';
|
||||
export { AgenticSectionType, ToolCallType } from './agentic.enums';
|
||||
|
||||
export {
|
||||
ChatMessageStatsView,
|
||||
|
|
@ -17,7 +17,7 @@ export {
|
|||
MessageType,
|
||||
PdfViewMode,
|
||||
ReasoningFormat
|
||||
} from './chat';
|
||||
} from './chat.enums';
|
||||
|
||||
export {
|
||||
FileTypeCategory,
|
||||
|
|
@ -38,7 +38,7 @@ export {
|
|||
MimeTypeImage,
|
||||
MimeTypeText,
|
||||
SpecialFileType
|
||||
} from './files';
|
||||
} from './files.enums';
|
||||
|
||||
export {
|
||||
MCPConnectionPhase,
|
||||
|
|
@ -48,16 +48,16 @@ export {
|
|||
MCPContentType,
|
||||
MCPRefType,
|
||||
JsonSchemaType
|
||||
} from './mcp';
|
||||
} from './mcp.enums';
|
||||
|
||||
export { ModelModality } from './model';
|
||||
export { ModelModality } from './model.enums';
|
||||
|
||||
export { ServerRole, ServerModelStatus } from './server';
|
||||
export { ServerRole, ServerModelStatus } from './server.enums';
|
||||
|
||||
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings';
|
||||
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
|
||||
|
||||
export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui';
|
||||
export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui.enums';
|
||||
|
||||
export { KeyboardKey } from './keyboard';
|
||||
export { KeyboardKey } from './keyboard.enums';
|
||||
|
||||
export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools';
|
||||
export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools.enums';
|
||||
|
|
|
|||
|
|
@ -100,6 +100,14 @@ export class AutoScrollController {
|
|||
this._autoScrollEnabled = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets scroll state when switching conversations.
|
||||
*/
|
||||
resetScrollState(): void {
|
||||
this._userScrolledUp = false;
|
||||
this._autoScrollEnabled = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts the auto-scroll interval for continuous scrolling during streaming.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
const serverModel = $derived(singleModelName());
|
||||
|
||||
const currentModel = $derived(opts.currentModel());
|
||||
const useGlobalSelection = $derived(opts.useGlobalSelection?.() ?? false);
|
||||
const onModelChange = $derived(opts.onModelChange?.());
|
||||
|
||||
const isHighlightedCurrentModelActive = $derived.by(() => {
|
||||
|
|
@ -128,6 +127,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
|
||||
if (onModelChange) {
|
||||
const result = await onModelChange(option.id, option.model);
|
||||
|
||||
if (result === false) {
|
||||
shouldCloseMenu = false;
|
||||
}
|
||||
|
|
@ -142,12 +142,14 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
const textarea = document.querySelector<HTMLTextAreaElement>(
|
||||
'[data-slot="chat-form"] textarea'
|
||||
);
|
||||
|
||||
textarea?.focus();
|
||||
});
|
||||
}
|
||||
|
||||
if (!onModelChange && isRouter && !modelsStore.isModelLoaded(option.model)) {
|
||||
isLoadingModel = true;
|
||||
|
||||
modelsStore
|
||||
.loadModel(option.model)
|
||||
.catch((error) => console.error('Failed to load model:', error))
|
||||
|
|
@ -158,6 +160,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
function getDisplayOption(): ModelOption | undefined {
|
||||
if (!isRouter) {
|
||||
const displayModel = serverModel || currentModel;
|
||||
|
||||
if (displayModel) {
|
||||
return {
|
||||
id: serverModel ? 'current' : 'offline-current',
|
||||
|
|
@ -166,12 +169,8 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
capabilities: []
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (useGlobalSelection && activeId) {
|
||||
const selected = options.find((option) => option.id === activeId);
|
||||
if (selected) return selected;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (currentModel) {
|
||||
|
|
@ -183,6 +182,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
capabilities: []
|
||||
};
|
||||
}
|
||||
|
||||
return options.find((option) => option.model === currentModel);
|
||||
}
|
||||
|
||||
|
|
@ -197,57 +197,77 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
|
|||
get options() {
|
||||
return options;
|
||||
},
|
||||
|
||||
get loading() {
|
||||
return loading;
|
||||
},
|
||||
|
||||
get updating() {
|
||||
return updating;
|
||||
},
|
||||
|
||||
get activeId() {
|
||||
return activeId;
|
||||
},
|
||||
|
||||
get isRouter() {
|
||||
return isRouter;
|
||||
},
|
||||
|
||||
get serverModel() {
|
||||
return serverModel;
|
||||
},
|
||||
|
||||
get isHighlightedCurrentModelActive() {
|
||||
return isHighlightedCurrentModelActive;
|
||||
},
|
||||
|
||||
get isCurrentModelInCache() {
|
||||
return isCurrentModelInCache;
|
||||
},
|
||||
|
||||
get filteredOptions() {
|
||||
return filteredOptions;
|
||||
},
|
||||
|
||||
get groupedFilteredOptions() {
|
||||
return groupedFilteredOptions;
|
||||
},
|
||||
|
||||
get isLoadingModel() {
|
||||
return isLoadingModel;
|
||||
},
|
||||
|
||||
get searchTerm() {
|
||||
return searchTerm;
|
||||
},
|
||||
|
||||
get showModelDialog() {
|
||||
return showModelDialog;
|
||||
},
|
||||
|
||||
get infoModelId() {
|
||||
return infoModelId;
|
||||
},
|
||||
|
||||
setSearchTerm(value: string) {
|
||||
searchTerm = value;
|
||||
},
|
||||
|
||||
setShowModelDialog(value: boolean) {
|
||||
showModelDialog = value;
|
||||
},
|
||||
|
||||
handleInfoClick,
|
||||
|
||||
handleSelect,
|
||||
|
||||
handleOpenChange,
|
||||
|
||||
isFavorite(model: string) {
|
||||
return modelsStore.favoriteModelIds.has(model);
|
||||
},
|
||||
|
||||
getDisplayOption
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -392,7 +392,7 @@ export class MCPService {
|
|||
|
||||
const url = new URL(config.url);
|
||||
|
||||
if (import.meta.env.DEV) {
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService] Creating WebSocket transport for ${url.href}`);
|
||||
}
|
||||
|
||||
|
|
@ -413,12 +413,12 @@ export class MCPService {
|
|||
onLog
|
||||
);
|
||||
|
||||
if (useProxy && import.meta.env.DEV) {
|
||||
if (useProxy && import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
|
||||
}
|
||||
|
||||
try {
|
||||
if (import.meta.env.DEV) {
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService] Creating StreamableHTTP transport for ${url.href}`);
|
||||
}
|
||||
|
||||
|
|
@ -520,7 +520,7 @@ export class MCPService {
|
|||
)
|
||||
);
|
||||
|
||||
if (import.meta.env.DEV) {
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService][${serverName}] Creating transport...`);
|
||||
}
|
||||
|
||||
|
|
@ -560,6 +560,22 @@ export class MCPService {
|
|||
);
|
||||
|
||||
const runtimeErrorHandler = (error: Error) => {
|
||||
// Ignore errors that are expected when the SDK's transport is closed,
|
||||
// or when connecting to servers that don't support SSE (stateless-only
|
||||
// endpoints returning 405). The SDK wraps the original AbortError in
|
||||
// a new Error with the message "SSE stream disconnected: AbortError",
|
||||
// and also produces "Cannot cancel a stream locked by a reader".
|
||||
// DOMException is thrown by the browser when aborting fetch requests.
|
||||
const msg = error.message || String(error);
|
||||
if (
|
||||
error.name === 'AbortError' ||
|
||||
error instanceof DOMException ||
|
||||
msg.includes('SSE stream disconnected') ||
|
||||
msg.includes('stream locked by a reader') ||
|
||||
msg.includes('The operation was aborted')
|
||||
) {
|
||||
return;
|
||||
}
|
||||
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
|
||||
};
|
||||
|
||||
|
|
@ -658,7 +674,10 @@ export class MCPService {
|
|||
this.createLog(MCPConnectionPhase.LISTING_TOOLS, 'Listing available tools...')
|
||||
);
|
||||
|
||||
console.log(`[MCPService][${serverName}] Connected, listing tools...`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService][${serverName}] Connected, listing tools...`);
|
||||
}
|
||||
|
||||
const tools = await this.listTools({
|
||||
client,
|
||||
transport,
|
||||
|
|
@ -680,10 +699,11 @@ export class MCPService {
|
|||
`Connection established with ${tools.length} tools (${connectionTimeMs}ms)`
|
||||
)
|
||||
);
|
||||
|
||||
console.log(
|
||||
`[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms`
|
||||
);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(
|
||||
`[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms`
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
client,
|
||||
|
|
@ -709,9 +729,22 @@ export class MCPService {
|
|||
* @param connection - The active MCP connection to close
|
||||
*/
|
||||
static async disconnect(connection: MCPConnection): Promise<void> {
|
||||
console.log(`[MCPService][${connection.serverName}] Disconnecting...`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService][${connection.serverName}] Disconnecting...`);
|
||||
}
|
||||
|
||||
try {
|
||||
// Prevent reconnection on voluntary disconnect
|
||||
// Terminate the session first for streamable-http transports to cleanly
|
||||
// close streams, matching the inspector's disconnect flow.
|
||||
if (connection.transport instanceof StreamableHTTPClientTransport) {
|
||||
await connection.transport.terminateSession();
|
||||
}
|
||||
|
||||
// Clear error handlers before closing to prevent noise from expected
|
||||
// abort errors during shutdown. The inspector avoids this entirely
|
||||
// by not setting onerror, but since we use it for protocol logging,
|
||||
// we must clear it before disconnect.
|
||||
connection.client.onerror = undefined;
|
||||
if (connection.transport.onclose) {
|
||||
connection.transport.onclose = undefined;
|
||||
}
|
||||
|
|
@ -1078,7 +1111,9 @@ export class MCPService {
|
|||
try {
|
||||
await connection.client.unsubscribeResource({ uri });
|
||||
|
||||
console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`[MCPService][${connection.serverName}] Failed to unsubscribe from resource:`,
|
||||
|
|
|
|||
|
|
@ -119,7 +119,8 @@ const localStorageMigration: Migration = {
|
|||
// Only migrate if new key doesn't already exist
|
||||
const newValue = localStorage.getItem(newKey);
|
||||
if (newValue !== null) {
|
||||
console.log(`[Migration] localStorage: ${newKey} already exists, skipping`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] localStorage: ${newKey} already exists, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -127,9 +128,11 @@ const localStorageMigration: Migration = {
|
|||
if (oldValue !== null) {
|
||||
localStorage.setItem(newKey, oldValue);
|
||||
// Keep old key for downgrade compatibility - DO NOT DELETE
|
||||
console.log(
|
||||
`[Migration] localStorage: copied ${deprecatedKey} → ${newKey} (preserved old)`
|
||||
);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
console.log(
|
||||
`[Migration] localStorage: copied ${deprecatedKey} → ${newKey} (preserved old)`
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -146,7 +149,8 @@ const idxdbMigration: Migration = {
|
|||
async run(): Promise<void> {
|
||||
const oldDbNames = await Dexie.getDatabaseNames();
|
||||
if (!oldDbNames.includes(DB_APP_NAME_DEPRECATED)) {
|
||||
console.log('[Migration] IndexedDB: no old database found, skipping');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] IndexedDB: no old database found, skipping');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -155,11 +159,13 @@ const idxdbMigration: Migration = {
|
|||
newDb.version(1).stores(IDXDB_STORES);
|
||||
const existingConvs = await newDb.table(IDXDB_TABLES.conversations).count();
|
||||
if (existingConvs > 0) {
|
||||
console.log('[Migration] IndexedDB: new database already has data, skipping');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] IndexedDB: new database already has data, skipping');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED);
|
||||
|
||||
const oldDb = new Dexie(DB_APP_NAME_DEPRECATED);
|
||||
oldDb.version(1).stores(IDXDB_STORES);
|
||||
|
|
@ -169,15 +175,18 @@ const idxdbMigration: Migration = {
|
|||
|
||||
if (conversations.length > 0) {
|
||||
await newDb.table(IDXDB_TABLES.conversations).bulkAdd(conversations);
|
||||
console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`);
|
||||
}
|
||||
if (messages.length > 0) {
|
||||
await newDb.table(IDXDB_TABLES.messages).bulkAdd(messages);
|
||||
console.log(`[Migration] IndexedDB: copied ${messages.length} messages`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] IndexedDB: copied ${messages.length} messages`);
|
||||
}
|
||||
|
||||
// Non-destructive: DO NOT delete old database - keep for downgrade compatibility
|
||||
console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility');
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -419,7 +428,8 @@ const legacyMessageMigration: Migration = {
|
|||
}
|
||||
}
|
||||
|
||||
console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -434,7 +444,8 @@ const themeMigration: Migration = {
|
|||
async run(): Promise<void> {
|
||||
const legacyTheme = localStorage.getItem('theme');
|
||||
if (legacyTheme === null) {
|
||||
console.log('[Migration] Theme: no legacy theme key found, skipping');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] Theme: no legacy theme key found, skipping');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -443,7 +454,8 @@ const themeMigration: Migration = {
|
|||
const config = configRaw ? JSON.parse(configRaw) : {};
|
||||
|
||||
if (SETTINGS_KEYS.THEME in config) {
|
||||
console.log('[Migration] Theme: config already has theme, skipping');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] Theme: config already has theme, skipping');
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -451,7 +463,8 @@ const themeMigration: Migration = {
|
|||
localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));
|
||||
|
||||
// Non-destructive: DO NOT delete legacy theme key - keep for downgrade compatibility
|
||||
console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -491,7 +504,8 @@ export const MigrationService = {
|
|||
*/
|
||||
resetState(): void {
|
||||
localStorage.removeItem(MIGRATION_STATE_KEY);
|
||||
console.log('[Migration] State reset - all migrations will run again');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] State reset - all migrations will run again');
|
||||
},
|
||||
|
||||
/**
|
||||
|
|
@ -500,25 +514,30 @@ export const MigrationService = {
|
|||
*/
|
||||
async runAllMigrations(): Promise<void> {
|
||||
const state = getMigrationState();
|
||||
console.log('[Migration] Starting migration run, state:', state);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] Starting migration run, state:', state);
|
||||
|
||||
for (const migration of migrations) {
|
||||
if (isMigrationCompleted(migration.id)) {
|
||||
console.log(`[Migration] ${migration.id}: already completed, skipping`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] ${migration.id}: already completed, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
console.log(`[Migration] ${migration.id}: running...`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] ${migration.id}: running...`);
|
||||
await migration.run();
|
||||
markMigrationCompleted(migration.id);
|
||||
console.log(`[Migration] ${migration.id}: completed successfully`);
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log(`[Migration] ${migration.id}: completed successfully`);
|
||||
} catch (error) {
|
||||
console.error(`[Migration] ${migration.id}: failed`, error);
|
||||
markMigrationFailed(migration.id);
|
||||
}
|
||||
}
|
||||
|
||||
console.log('[Migration] All migrations complete');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] All migrations complete');
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -20,11 +20,11 @@
|
|||
*/
|
||||
|
||||
import { browser } from '$app/environment';
|
||||
import { base } from '$app/paths';
|
||||
import { SETTINGS_KEYS } from '$lib/constants';
|
||||
import { MCPService } from '$lib/services/mcp.service';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { mode } from 'mode-watcher';
|
||||
import {
|
||||
parseMcpServerSettings,
|
||||
|
|
@ -43,7 +43,6 @@ import {
|
|||
ToolCallType
|
||||
} from '$lib/enums';
|
||||
import {
|
||||
CORS_PROXY_ENDPOINT,
|
||||
DEFAULT_CACHE_TTL_MS,
|
||||
DEFAULT_MCP_CONFIG,
|
||||
EXPECTED_THEMED_ICON_PAIR_COUNT,
|
||||
|
|
@ -86,7 +85,6 @@ class MCPStore {
|
|||
private _toolCount = $state(0);
|
||||
private _connectedServers = $state<string[]>([]);
|
||||
private _healthChecks = $state<Record<string, HealthCheckState>>({});
|
||||
private _proxyAvailable = $state(false);
|
||||
|
||||
private connections = new Map<string, MCPConnection>();
|
||||
private toolsIndex = new Map<string, string>();
|
||||
|
|
@ -96,27 +94,8 @@ class MCPStore {
|
|||
private initPromise: Promise<boolean> | null = null;
|
||||
private activeFlowCount = 0;
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
this.probeProxy();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Probes the CORS proxy endpoint to determine availability.
|
||||
* The endpoint is only registered when llama-server runs with --ui-mcp-proxy.
|
||||
*/
|
||||
async probeProxy(): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`${base}${CORS_PROXY_ENDPOINT}`, { method: 'HEAD' });
|
||||
this._proxyAvailable = response.status !== 404;
|
||||
} catch {
|
||||
this._proxyAvailable = false;
|
||||
}
|
||||
}
|
||||
|
||||
get isProxyAvailable(): boolean {
|
||||
return this._proxyAvailable;
|
||||
return serverStore.props?.cors_proxy_enabled ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import { toast } from 'svelte-sonner';
|
|||
import { ServerModelStatus, ModelModality } from '$lib/enums';
|
||||
import { ModelsService } from '$lib/services/models.service';
|
||||
import { PropsService } from '$lib/services/props.service';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { TTLCache } from '$lib/utils';
|
||||
import {
|
||||
MODEL_PROPS_CACHE_TTL_MS,
|
||||
|
|
@ -14,14 +14,7 @@ import {
|
|||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
|
||||
/**
|
||||
* modelsStore - Reactive store for model management in both MODEL and ROUTER modes
|
||||
*
|
||||
* This store manages:
|
||||
* - Available models list
|
||||
* - Selected model for new conversations
|
||||
* - Loaded models tracking (ROUTER mode)
|
||||
* - Model usage tracking per conversation
|
||||
* - Automatic unloading of unused models
|
||||
* modelsStore - Reactive store for model management in both MODEL and ROUTER modes.
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **ModelsService**: Stateless service for model API communication
|
||||
|
|
@ -31,14 +24,8 @@ import { conversationsStore } from '$lib/stores/conversations.svelte';
|
|||
*
|
||||
* **API Inconsistency Workaround:**
|
||||
* In MODEL mode, `/props` returns modalities for the single model.
|
||||
* In ROUTER mode, `/props` has no modalities - must use `/props?model=<id>` per model.
|
||||
* In ROUTER mode, `/props` has no modalities — must use `/props?model=<id>` per model.
|
||||
* This store normalizes this behavior so consumers don't need to know the server mode.
|
||||
*
|
||||
* **Key Features:**
|
||||
* - **MODEL mode**: Single model, always loaded
|
||||
* - **ROUTER mode**: Multi-model with load/unload capability
|
||||
* - **Auto-unload**: Automatically unloads models not used by any conversation
|
||||
* - **Lazy loading**: ensureModelLoaded() loads models on demand
|
||||
*/
|
||||
class ModelsStore {
|
||||
/**
|
||||
|
|
@ -57,8 +44,8 @@ class ModelsStore {
|
|||
selectedModelId = $state<string | null>(null);
|
||||
selectedModelName = $state<string | null>(null);
|
||||
|
||||
// dedup concurrent fetch() callers, all awaiters share the same inflight promise
|
||||
// without this, ?model=<name> URL handler raced an in-progress fetch and saw an empty list
|
||||
// Dedup concurrent fetch() callers — all awaiters share the same inflight promise.
|
||||
// Without this, ?model=<name> URL handler races an in-progress fetch and sees an empty list.
|
||||
private inflightFetch: Promise<void> | null = null;
|
||||
|
||||
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
|
||||
|
|
@ -67,9 +54,9 @@ class ModelsStore {
|
|||
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
|
||||
|
||||
/**
|
||||
* Model-specific props cache with TTL
|
||||
* Key: modelId, Value: props data including modalities
|
||||
* TTL: 10 minutes - props don't change frequently
|
||||
* Model-specific props cache with TTL.
|
||||
* Key: modelId, Value: props data including modalities.
|
||||
* TTL: 10 minutes — props don't change frequently.
|
||||
*/
|
||||
private modelPropsCache = new TTLCache<string, ApiLlamaCppServerProps>({
|
||||
ttlMs: MODEL_PROPS_CACHE_TTL_MS,
|
||||
|
|
@ -78,7 +65,7 @@ class ModelsStore {
|
|||
private modelPropsFetching = $state<Set<string>>(new Set());
|
||||
|
||||
/**
|
||||
* Version counter for props cache - used to trigger reactivity when props are updated
|
||||
* Version counter for props cache — used to trigger reactivity when props are updated.
|
||||
*/
|
||||
propsCacheVersion = $state(0);
|
||||
|
||||
|
|
@ -92,7 +79,7 @@ class ModelsStore {
|
|||
|
||||
get selectedModel(): ModelOption | null {
|
||||
if (!this.selectedModelId) return null;
|
||||
return this.models.find((model) => model.id === this.selectedModelId) ?? null;
|
||||
return this.models.find((m) => m.id === this.selectedModelId) ?? null;
|
||||
}
|
||||
|
||||
get loadedModelIds(): string[] {
|
||||
|
|
@ -117,7 +104,7 @@ class ModelsStore {
|
|||
* In ROUTER mode, returns null (model is per-conversation).
|
||||
*/
|
||||
get singleModelName(): string | null {
|
||||
if (serverStore.isRouterMode) return null;
|
||||
if (isRouterMode()) return null;
|
||||
|
||||
const props = serverStore.props;
|
||||
if (props?.model_alias) return props.model_alias;
|
||||
|
|
@ -126,6 +113,11 @@ class ModelsStore {
|
|||
return props.model_path.split(/(\\|\/)/).pop() || null;
|
||||
}
|
||||
|
||||
get selectedModelContextSize(): number | null {
|
||||
if (!this.selectedModelName) return null;
|
||||
return this.getModelContextSize(this.selectedModelName);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
|
|
@ -134,10 +126,6 @@ class ModelsStore {
|
|||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Get modalities for a specific model
|
||||
* Returns cached modalities from model props
|
||||
*/
|
||||
getModelModalities(modelId: string): ModelModalities | null {
|
||||
const model = this.models.find((m) => m.model === modelId || m.id === modelId);
|
||||
if (model?.modalities) {
|
||||
|
|
@ -146,46 +134,29 @@ class ModelsStore {
|
|||
|
||||
const props = this.modelPropsCache.get(modelId);
|
||||
if (props?.modalities) {
|
||||
return {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
return this.buildModalities(props.modalities);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports vision modality
|
||||
*/
|
||||
modelSupportsVision(modelId: string): boolean {
|
||||
return this.getModelModalities(modelId)?.vision ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports audio modality
|
||||
*/
|
||||
modelSupportsAudio(modelId: string): boolean {
|
||||
return this.getModelModalities(modelId)?.audio ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports video modality
|
||||
*/
|
||||
modelSupportsVideo(modelId: string): boolean {
|
||||
return this.getModelModalities(modelId)?.video ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model modalities as an array of ModelModality enum values
|
||||
*/
|
||||
getModelModalitiesArray(modelId: string): ModelModality[] {
|
||||
const modalities = this.getModelModalities(modelId);
|
||||
if (!modalities) return [];
|
||||
|
||||
const result: ModelModality[] = [];
|
||||
|
||||
if (modalities.vision) result.push(ModelModality.VISION);
|
||||
if (modalities.audio) result.push(ModelModality.AUDIO);
|
||||
if (modalities.video) result.push(ModelModality.VIDEO);
|
||||
|
|
@ -193,16 +164,10 @@ class ModelsStore {
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get props for a specific model (from cache)
|
||||
*/
|
||||
getModelProps(modelId: string): ApiLlamaCppServerProps | null {
|
||||
return this.modelPropsCache.get(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get context size (n_ctx) for a specific model from cached props
|
||||
*/
|
||||
getModelContextSize(modelId: string): number | null {
|
||||
const props = this.getModelProps(modelId);
|
||||
const nCtx = props?.default_generation_settings?.n_ctx;
|
||||
|
|
@ -210,17 +175,6 @@ class ModelsStore {
|
|||
return typeof nCtx === 'number' ? nCtx : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get context size for the currently selected model or null if no model is selected
|
||||
*/
|
||||
get selectedModelContextSize(): number | null {
|
||||
if (!this.selectedModelName) return null;
|
||||
return this.getModelContextSize(this.selectedModelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if props are being fetched for a model
|
||||
*/
|
||||
isModelPropsFetching(modelId: string): boolean {
|
||||
return this.modelPropsFetching.has(modelId);
|
||||
}
|
||||
|
|
@ -235,10 +189,10 @@ class ModelsStore {
|
|||
|
||||
isModelLoaded(modelId: string): boolean {
|
||||
const model = this.routerModels.find((m) => m.id === modelId);
|
||||
|
||||
return (
|
||||
model?.status.value === ServerModelStatus.LOADED ||
|
||||
model?.status.value === ServerModelStatus.SLEEPING ||
|
||||
false
|
||||
model?.status.value === ServerModelStatus.SLEEPING
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -248,6 +202,7 @@ class ModelsStore {
|
|||
|
||||
getModelStatus(modelId: string): ServerModelStatus | null {
|
||||
const model = this.routerModels.find((m) => m.id === modelId);
|
||||
|
||||
return model?.status.value ?? null;
|
||||
}
|
||||
|
||||
|
|
@ -257,6 +212,7 @@ class ModelsStore {
|
|||
|
||||
isModelInUse(modelId: string): boolean {
|
||||
const usage = this.modelUsage.get(modelId);
|
||||
|
||||
return usage !== undefined && usage.size > 0;
|
||||
}
|
||||
|
||||
|
|
@ -269,8 +225,8 @@ class ModelsStore {
|
|||
*/
|
||||
|
||||
/**
|
||||
* Fetch list of models from server and detect server role
|
||||
* Also fetches modalities for MODEL mode (single model)
|
||||
* Fetch list of models from server and detect server role.
|
||||
* Also fetches modalities for MODEL mode (single model).
|
||||
*/
|
||||
async fetch(force = false): Promise<void> {
|
||||
if (this.inflightFetch) return this.inflightFetch;
|
||||
|
|
@ -293,69 +249,87 @@ class ModelsStore {
|
|||
await serverStore.fetch();
|
||||
}
|
||||
|
||||
const response = await ModelsService.list();
|
||||
const router = isRouterMode();
|
||||
|
||||
const models: ModelOption[] = response.data.map((item: ApiModelDataEntry, index: number) => {
|
||||
const details = response.models?.[index];
|
||||
const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : [];
|
||||
const displayNameSource =
|
||||
details?.name && details.name.trim().length > 0 ? details.name : item.id;
|
||||
const displayName = this.toDisplayName(displayNameSource);
|
||||
const modelId = details?.model || item.id;
|
||||
if (router) {
|
||||
const response = await ModelsService.listRouter();
|
||||
|
||||
return {
|
||||
id: item.id,
|
||||
name: displayName,
|
||||
model: modelId,
|
||||
description: details?.description,
|
||||
capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)),
|
||||
details: details?.details,
|
||||
meta: item.meta ?? null,
|
||||
parsedId: ModelsService.parseModelId(modelId),
|
||||
aliases: item.aliases ?? [],
|
||||
tags: item.tags ?? []
|
||||
} satisfies ModelOption;
|
||||
});
|
||||
this.routerModels = response.data;
|
||||
this.models = this.buildModelOptions(response);
|
||||
|
||||
this.models = models;
|
||||
await this.fetchModalitiesForLoadedModels();
|
||||
|
||||
// WORKAROUND: In MODEL mode, /props returns modalities for the single model,
|
||||
// but /v1/models doesn't include modalities. We bridge this gap here.
|
||||
const serverProps = serverStore.props;
|
||||
if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) {
|
||||
const modalities: ModelModalities = {
|
||||
vision: serverProps.modalities.vision ?? false,
|
||||
audio: serverProps.modalities.audio ?? false,
|
||||
video: serverProps.modalities.video ?? false
|
||||
};
|
||||
this.modelPropsCache.set(this.models[0].model, serverProps);
|
||||
this.models = this.models.map((model, index) =>
|
||||
index === 0 ? { ...model, modalities } : model
|
||||
);
|
||||
const visible = this.getVisibleModels();
|
||||
|
||||
if (visible.length === 1 && this.isModelLoaded(visible[0].model)) {
|
||||
this.selectModelById(visible[0].id);
|
||||
}
|
||||
} else {
|
||||
this.models = await this.fetchModelModeInternal();
|
||||
}
|
||||
} catch (error) {
|
||||
this.models = [];
|
||||
this.error = error instanceof Error ? error.message : 'Failed to load models';
|
||||
|
||||
throw error;
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
/** Fetch models in MODEL mode (single model, standard OpenAI-compatible). */
|
||||
private async fetchModelModeInternal(): Promise<ModelOption[]> {
|
||||
const response = await ModelsService.list();
|
||||
|
||||
return this.buildModelOptions(response);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch router models with full metadata (ROUTER mode only)
|
||||
* This fetches the /models endpoint which returns status info for each model
|
||||
* Build ModelOption[] from an API response.
|
||||
* Both MODEL and ROUTER modes share the same mapping logic;
|
||||
* they differ only in which endpoint is called.
|
||||
*/
|
||||
private buildModelOptions(
|
||||
response: ApiModelListResponse | ApiRouterModelsListResponse
|
||||
): ModelOption[] {
|
||||
return response.data.map((item: ApiModelDataEntry, index: number) => {
|
||||
const details = response.models?.[index];
|
||||
const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : [];
|
||||
const displayNameSource =
|
||||
details?.name && details.name.trim().length > 0 ? details.name : item.id;
|
||||
const modelId = details?.model || item.id;
|
||||
|
||||
return {
|
||||
id: item.id,
|
||||
name: this.toDisplayName(displayNameSource),
|
||||
model: modelId,
|
||||
description: details?.description,
|
||||
capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)),
|
||||
details: details?.details,
|
||||
meta: item.meta ?? null,
|
||||
parsedId: ModelsService.parseModelId(modelId),
|
||||
aliases: item.aliases ?? [],
|
||||
tags: item.tags ?? []
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch router models with full metadata (ROUTER mode only).
|
||||
* No-op in router mode — fetch() already calls listRouter() internally.
|
||||
* Kept for API compatibility (e.g. handleOpenChange dropdown open handler).
|
||||
*/
|
||||
async fetchRouterModels(): Promise<void> {
|
||||
if (!isRouterMode()) return;
|
||||
|
||||
try {
|
||||
const response = await ModelsService.listRouter();
|
||||
this.routerModels = response.data;
|
||||
await this.fetchModalitiesForLoadedModels();
|
||||
|
||||
const o = this.models.filter((option) => this.getModelProps(option.model)?.ui !== false);
|
||||
|
||||
if (o.length === 1 && this.isModelLoaded(o[0].model)) {
|
||||
this.selectModelById(o[0].id);
|
||||
const visible = this.getVisibleModels();
|
||||
if (visible.length === 1 && this.isModelLoaded(visible[0].model)) {
|
||||
this.selectModelById(visible[0].id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to fetch router models:', error);
|
||||
|
|
@ -364,10 +338,10 @@ class ModelsStore {
|
|||
}
|
||||
|
||||
/**
|
||||
* Fetch props for a specific model from /props endpoint
|
||||
* Uses caching to avoid redundant requests
|
||||
* Fetch props for a specific model from /props endpoint.
|
||||
* Uses caching to avoid redundant requests.
|
||||
*
|
||||
* In ROUTER mode, this will only fetch props if the model is loaded,
|
||||
* In ROUTER mode, this only fetches props if the model is loaded,
|
||||
* since unloaded models return 400 from /props endpoint.
|
||||
*
|
||||
* @param modelId - Model identifier to fetch props for
|
||||
|
|
@ -397,10 +371,7 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch modalities for all loaded models from /props endpoint
|
||||
* This updates the modalities field in models array
|
||||
*/
|
||||
/** Fetch modalities for all loaded models from /props endpoint. */
|
||||
async fetchModalitiesForLoadedModels(): Promise<void> {
|
||||
const loadedModelIds = this.loadedModelIds;
|
||||
if (loadedModelIds.length === 0) return;
|
||||
|
|
@ -410,7 +381,6 @@ class ModelsStore {
|
|||
try {
|
||||
const results = await Promise.all(propsPromises);
|
||||
|
||||
// Update models with modalities
|
||||
this.models = this.models.map((model) => {
|
||||
const modelIndex = loadedModelIds.indexOf(model.model);
|
||||
if (modelIndex === -1) return model;
|
||||
|
|
@ -418,13 +388,7 @@ class ModelsStore {
|
|||
const props = results[modelIndex];
|
||||
if (!props?.modalities) return model;
|
||||
|
||||
const modalities: ModelModalities = {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
|
||||
return { ...model, modalities };
|
||||
return { ...model, modalities: this.buildModalities(props.modalities) };
|
||||
});
|
||||
|
||||
this.propsCacheVersion++;
|
||||
|
|
@ -433,17 +397,38 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update modalities for a specific model.
|
||||
* Called when a model is loaded or when we need fresh modality data.
|
||||
*/
|
||||
async updateModelModalities(modelId: string): Promise<void> {
|
||||
const props = await this.fetchModelProps(modelId);
|
||||
if (!props?.modalities) return;
|
||||
|
||||
this.models = this.models.map((model) =>
|
||||
model.model === modelId
|
||||
? { ...model, modalities: this.buildModalities(props.modalities!) }
|
||||
: model
|
||||
);
|
||||
|
||||
this.propsCacheVersion++;
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter to models visible in the UI (ui !== false).
|
||||
*/
|
||||
private getVisibleModels(): ModelOption[] {
|
||||
return this.models.filter((option) => this.getModelProps(option.model)?.ui !== false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the model name from the last assistant message in the active conversation.
|
||||
* Iterates backward through messages to find the most recent message with a model.
|
||||
* Used by both the chat page and settings page to maintain model consistency.
|
||||
* @returns The model name or null if not found
|
||||
*/
|
||||
getModelFromLastAssistantResponse(): string | null {
|
||||
const messages = conversationsStore.activeMessages;
|
||||
if (!messages || messages.length === 0) return null;
|
||||
|
||||
// Iterate backward to find the last message with a model
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].model) {
|
||||
return messages[i].model;
|
||||
|
|
@ -456,22 +441,13 @@ class ModelsStore {
|
|||
/**
|
||||
* Auto-selects the model from the last assistant response if available and loaded.
|
||||
* Returns true if a model was selected, false otherwise.
|
||||
* This is used by the chat page to maintain model consistency across page navigation.
|
||||
*/
|
||||
async selectModelFromLastAssistantResponse(): Promise<boolean> {
|
||||
const lastModel = this.getModelFromLastAssistantResponse();
|
||||
if (!lastModel) return false;
|
||||
|
||||
// Skip if already selected
|
||||
if (this.selectedModelName === lastModel) return false;
|
||||
if (!lastModel || this.selectedModelName === lastModel) return false;
|
||||
|
||||
const matchingModel = this.models.find((option) => option.model === lastModel);
|
||||
if (!matchingModel) return false;
|
||||
|
||||
if (!this.isModelLoaded(lastModel)) {
|
||||
console.log('[modelsStore] last assistant model not loaded:', lastModel);
|
||||
return false;
|
||||
}
|
||||
if (!matchingModel || !this.isModelLoaded(lastModel)) return false;
|
||||
|
||||
try {
|
||||
await this.selectModelById(matchingModel.id);
|
||||
|
|
@ -484,22 +460,17 @@ class ModelsStore {
|
|||
}
|
||||
|
||||
/**
|
||||
* Auto-selects the first available model if none is selected, and fetches its props.
|
||||
* Auto-selects the first available model if none is selected.
|
||||
* Prioritizes:
|
||||
* 1. Model from active conversation's last assistant response (if loaded)
|
||||
* 2. Model from active conversation's last assistant response (if not loaded)
|
||||
* 3. First loaded model (not from active conversation)
|
||||
* 4. First available model
|
||||
* This is used to ensure default values are populated in settings pages.
|
||||
*/
|
||||
async ensureFirstModelSelected(): Promise<void> {
|
||||
if (this.selectedModelName) return;
|
||||
|
||||
// Filter models that are visible in the UI
|
||||
const availableModels = this.models.filter(
|
||||
(option) => this.getModelProps(option.model)?.ui !== false
|
||||
);
|
||||
|
||||
const availableModels = this.getVisibleModels();
|
||||
if (availableModels.length === 0) return;
|
||||
|
||||
// Try to select model from last assistant response first
|
||||
|
|
@ -515,7 +486,7 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
// Try to find a loaded model first
|
||||
// Try a loaded model first
|
||||
const loadedModel = availableModels.find((m) => this.isModelLoaded(m.model));
|
||||
if (loadedModel) {
|
||||
await this.selectModelById(loadedModel.id);
|
||||
|
|
@ -524,34 +495,7 @@ class ModelsStore {
|
|||
}
|
||||
|
||||
// Fall back to the first available model
|
||||
const firstModel = availableModels[0];
|
||||
await this.selectModelById(firstModel.id);
|
||||
// Don't fetch props for unloaded models (will fail in ROUTER mode)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update modalities for a specific model
|
||||
* Called when a model is loaded or when we need fresh modality data
|
||||
*/
|
||||
async updateModelModalities(modelId: string): Promise<void> {
|
||||
try {
|
||||
const props = await this.fetchModelProps(modelId);
|
||||
if (!props?.modalities) return;
|
||||
|
||||
const modalities: ModelModalities = {
|
||||
vision: props.modalities.vision ?? false,
|
||||
audio: props.modalities.audio ?? false,
|
||||
video: props.modalities.video ?? false
|
||||
};
|
||||
|
||||
this.models = this.models.map((model) =>
|
||||
model.model === modelId ? { ...model, modalities } : model
|
||||
);
|
||||
|
||||
this.propsCacheVersion++;
|
||||
} catch (error) {
|
||||
console.warn(`Failed to update modalities for model ${modelId}:`, error);
|
||||
}
|
||||
await this.selectModelById(availableModels[0].id);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -562,9 +506,6 @@ class ModelsStore {
|
|||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Select a model for new conversations
|
||||
*/
|
||||
async selectModelById(modelId: string): Promise<void> {
|
||||
if (!modelId || this.updating) return;
|
||||
if (this.selectedModelId === modelId) return;
|
||||
|
|
@ -584,8 +525,7 @@ class ModelsStore {
|
|||
}
|
||||
|
||||
/**
|
||||
* Select a model by its model name (used for syncing with conversation model)
|
||||
* @param modelName - Model name to select (e.g., "ggml-org/GLM-4.7-Flash-GGUF")
|
||||
* Select a model by its model name (used for syncing with conversation model).
|
||||
*/
|
||||
selectModelByName(modelName: string): void {
|
||||
const option = this.models.find((model) => model.model === modelName);
|
||||
|
|
@ -615,7 +555,7 @@ class ModelsStore {
|
|||
/**
|
||||
*
|
||||
*
|
||||
* Loading/Unloading Models
|
||||
* Loading / Unloading Models
|
||||
*
|
||||
*
|
||||
*/
|
||||
|
|
@ -623,27 +563,18 @@ class ModelsStore {
|
|||
/**
|
||||
* WORKAROUND: Polling for model status after load/unload operations.
|
||||
*
|
||||
* Currently, the `/models/load` and `/models/unload` endpoints return success
|
||||
* before the operation actually completes on the server. This means an immediate
|
||||
* request to `/models` returns stale status (e.g., "loading" after load request,
|
||||
* "loaded" after unload request).
|
||||
* Currently, `/models/load` and `/models/unload` return success before
|
||||
* the operation actually completes on the server.
|
||||
*
|
||||
* TODO: Remove this polling once llama-server properly waits for the operation
|
||||
* to complete before returning success from `/load` and `/unload` endpoints.
|
||||
* At that point, a single `fetchRouterModels()` call after the operation will
|
||||
* be sufficient to get the correct status.
|
||||
* TODO: Remove polling once llama-server properly waits for the operation
|
||||
* to complete before returning success.
|
||||
*/
|
||||
|
||||
/** Polling interval in ms for checking model status */
|
||||
private static readonly STATUS_POLL_INTERVAL = 500;
|
||||
|
||||
/**
|
||||
* Poll for expected model status after load/unload operation.
|
||||
* Keeps polling indefinitely until the model reaches the expected status or fails.
|
||||
*
|
||||
* @param modelId - Model identifier to check
|
||||
* @param expectedStatus - Expected status to wait for
|
||||
* @throws Error if model reaches FAILED status
|
||||
* Keeps polling until the model reaches the expected status or fails.
|
||||
*/
|
||||
private async pollForModelStatus(
|
||||
modelId: string,
|
||||
|
|
@ -654,9 +585,7 @@ class ModelsStore {
|
|||
await this.fetchRouterModels();
|
||||
|
||||
const currentStatus = this.getModelStatus(modelId);
|
||||
if (currentStatus === expectedStatus) {
|
||||
return;
|
||||
}
|
||||
if (currentStatus === expectedStatus) return;
|
||||
|
||||
if (currentStatus === ServerModelStatus.FAILED) {
|
||||
throw new Error(
|
||||
|
|
@ -677,15 +606,8 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a model (ROUTER mode)
|
||||
* @param modelId - Model identifier to load
|
||||
*/
|
||||
async loadModel(modelId: string): Promise<void> {
|
||||
if (this.isModelLoaded(modelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.isModelLoaded(modelId)) return;
|
||||
if (this.modelLoadingStates.get(modelId)) return;
|
||||
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
|
|
@ -694,7 +616,6 @@ class ModelsStore {
|
|||
try {
|
||||
await ModelsService.load(modelId);
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
|
||||
|
||||
await this.updateModelModalities(modelId);
|
||||
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
|
|
@ -706,15 +627,8 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unload a model (ROUTER mode)
|
||||
* @param modelId - Model identifier to unload
|
||||
*/
|
||||
async unloadModel(modelId: string): Promise<void> {
|
||||
if (!this.isModelLoaded(modelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this.isModelLoaded(modelId)) return;
|
||||
if (this.modelLoadingStates.get(modelId)) return;
|
||||
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
|
|
@ -722,7 +636,6 @@ class ModelsStore {
|
|||
|
||||
try {
|
||||
await ModelsService.unload(modelId);
|
||||
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
|
|
@ -734,15 +647,8 @@ class ModelsStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure a model is loaded before use
|
||||
* @param modelId - Model identifier to ensure is loaded
|
||||
*/
|
||||
async ensureModelLoaded(modelId: string): Promise<void> {
|
||||
if (this.isModelLoaded(modelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.isModelLoaded(modelId)) return;
|
||||
await this.loadModel(modelId);
|
||||
}
|
||||
|
||||
|
|
@ -779,11 +685,9 @@ class ModelsStore {
|
|||
private loadFavoritesFromStorage(): Set<string> {
|
||||
try {
|
||||
const raw = localStorage.getItem(FAVORITE_MODELS_LOCALSTORAGE_KEY);
|
||||
|
||||
return raw ? new Set(JSON.parse(raw) as string[]) : new Set();
|
||||
} catch {
|
||||
toast.error('Failed to load favorite models from local storage');
|
||||
|
||||
return new Set();
|
||||
}
|
||||
}
|
||||
|
|
@ -799,10 +703,19 @@ class ModelsStore {
|
|||
private toDisplayName(id: string): string {
|
||||
const segments = id.split(/\\|\//);
|
||||
const candidate = segments.pop();
|
||||
|
||||
return candidate && candidate.trim().length > 0 ? candidate : id;
|
||||
}
|
||||
|
||||
private buildModalities(
|
||||
modalities: NonNullable<ApiLlamaCppServerProps['modalities']>
|
||||
): ModelModalities {
|
||||
return {
|
||||
vision: modalities.vision ?? false,
|
||||
audio: modalities.audio ?? false,
|
||||
video: modalities.video ?? false
|
||||
};
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.models = [];
|
||||
this.routerModels = [];
|
||||
|
|
|
|||
1
tools/ui/src/lib/types/api.d.ts
vendored
1
tools/ui/src/lib/types/api.d.ts
vendored
|
|
@ -203,6 +203,7 @@ export interface ApiLlamaCppServerProps {
|
|||
/** @deprecated Use {@link ui_settings} instead */
|
||||
webui_settings?: Record<string, string | number | boolean>;
|
||||
ui_settings?: Record<string, string | number | boolean>;
|
||||
cors_proxy_enabled?: boolean;
|
||||
}
|
||||
|
||||
export interface ApiChatCompletionRequest {
|
||||
|
|
|
|||
4
tools/ui/src/lib/types/mcp.d.ts
vendored
4
tools/ui/src/lib/types/mcp.d.ts
vendored
|
|
@ -1,5 +1,5 @@
|
|||
import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp';
|
||||
import type { ToolSource } from '$lib/enums/tools';
|
||||
import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp.enums';
|
||||
import type { ToolSource } from '$lib/enums/tools.enums';
|
||||
import type {
|
||||
Client,
|
||||
ClientCapabilities as SDKClientCapabilities,
|
||||
|
|
|
|||
|
|
@ -12,17 +12,21 @@ export async function validateApiKey(fetch: typeof globalThis.fetch): Promise<vo
|
|||
return;
|
||||
}
|
||||
|
||||
const apiKey = config().apiKey;
|
||||
|
||||
// No API key configured — server doesn't require auth, skip the request entirely.
|
||||
// The /props endpoint is only protected when the server has API keys configured,
|
||||
// and in that case the client always has one set (from settings).
|
||||
if (!apiKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const apiKey = config().apiKey;
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json'
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`
|
||||
};
|
||||
|
||||
if (apiKey) {
|
||||
headers.Authorization = `Bearer ${apiKey}`;
|
||||
}
|
||||
|
||||
const response = await fetch(`${base}/props`, { headers });
|
||||
|
||||
if (!response.ok) {
|
||||
|
|
|
|||
|
|
@ -333,7 +333,8 @@ async function migrateConversation(convId: string): Promise<number> {
|
|||
export async function runLegacyMigration(): Promise<void> {
|
||||
if (!isMigrationNeeded()) return;
|
||||
|
||||
console.log('[Migration] Starting legacy message format migration...');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
|
||||
console.log('[Migration] Starting legacy message format migration...');
|
||||
|
||||
try {
|
||||
const conversations = await DatabaseService.getAllConversations();
|
||||
|
|
@ -344,12 +345,14 @@ export async function runLegacyMigration(): Promise<void> {
|
|||
totalMigrated += count;
|
||||
}
|
||||
|
||||
if (totalMigrated > 0) {
|
||||
console.log(
|
||||
`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
|
||||
);
|
||||
} else {
|
||||
console.log('[Migration] No legacy messages found, marking as done');
|
||||
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
|
||||
if (totalMigrated > 0) {
|
||||
console.log(
|
||||
`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
|
||||
);
|
||||
} else {
|
||||
console.log('[Migration] No legacy messages found, marking as done');
|
||||
}
|
||||
}
|
||||
|
||||
markMigrationDone();
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore, isConversationsInitialized } from '$lib/stores/conversations.svelte';
|
||||
import { modelsStore, modelOptions } from '$lib/stores/models.svelte';
|
||||
import { isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { page } from '$app/state';
|
||||
import { replaceState } from '$app/navigation';
|
||||
|
|
@ -72,23 +71,13 @@
|
|||
conversationsStore.clearActiveConversation();
|
||||
chatStore.clearUIState();
|
||||
|
||||
if (
|
||||
isRouterMode() &&
|
||||
modelsStore.selectedModelName &&
|
||||
!modelsStore.isModelLoaded(modelsStore.selectedModelName)
|
||||
) {
|
||||
modelsStore.clearSelection();
|
||||
await modelsStore.fetch();
|
||||
|
||||
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
|
||||
if (first) {
|
||||
await modelsStore.selectModelById(first.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle URL params only if we have ?q= or ?model= or ?new_chat=true
|
||||
if (qParam !== null || modelParam !== null || newChatParam === 'true') {
|
||||
await handleUrlParams();
|
||||
}
|
||||
|
||||
await modelsStore.ensureFirstModelSelected();
|
||||
});
|
||||
</script>
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,13 @@
|
|||
import { untrack } from 'svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { fade } from 'svelte/transition';
|
||||
|
||||
import {
|
||||
DesktopIconStrip,
|
||||
DialogConversationTitleUpdate,
|
||||
SidebarNavigation
|
||||
} from '$lib/components/app';
|
||||
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
|
|
@ -30,26 +32,29 @@
|
|||
import { conversations } from '$lib/stores/conversations.svelte';
|
||||
|
||||
let { children } = $props();
|
||||
|
||||
let alwaysShowSidebarOnDesktop = $derived(config().alwaysShowSidebarOnDesktop);
|
||||
let isMobile = new IsMobile();
|
||||
let isDesktop = $derived(!isMobile.current);
|
||||
let sidebarOpen = $state(false);
|
||||
let mounted = $state(false);
|
||||
let innerHeight = $state<number | undefined>();
|
||||
|
||||
let chatSidebar:
|
||||
| { activateSearchMode?: () => void; editActiveConversation?: () => void }
|
||||
| {
|
||||
activateSearchMode?: () => void;
|
||||
editActiveConversation?: () => void;
|
||||
}
|
||||
| undefined = $state();
|
||||
|
||||
let titleUpdateDialogOpen = $state(false);
|
||||
let titleUpdateCurrentTitle = $state('');
|
||||
let titleUpdateNewTitle = $state('');
|
||||
let titleUpdateResolve: ((value: boolean) => void) | null = null;
|
||||
|
||||
const panelNav = useSettingsNavigation();
|
||||
|
||||
function navigateToConversation(direction: -1 | 1) {
|
||||
const allConvs = conversations();
|
||||
|
||||
if (allConvs.length === 0) return;
|
||||
|
||||
const currentId = page.params.id;
|
||||
|
|
@ -61,6 +66,7 @@
|
|||
}
|
||||
|
||||
const idx = allConvs.findIndex((c) => c.id === currentId);
|
||||
|
||||
if (idx === -1) return;
|
||||
|
||||
const targetIdx = idx + direction;
|
||||
|
|
@ -75,38 +81,41 @@
|
|||
// Global keyboard shortcuts
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
editActiveConversation: () => chatSidebar?.editActiveConversation?.(),
|
||||
|
||||
navigateToPrevConversation: () => navigateToConversation(-1),
|
||||
|
||||
navigateToNextConversation: () => navigateToConversation(1)
|
||||
});
|
||||
|
||||
function checkApiKey() {
|
||||
const apiKey = config().apiKey;
|
||||
|
||||
if (
|
||||
(page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') &&
|
||||
page.status !== 401 &&
|
||||
page.status !== 403
|
||||
) {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json'
|
||||
};
|
||||
|
||||
if (apiKey && apiKey.trim() !== '') {
|
||||
headers.Authorization = `Bearer ${apiKey.trim()}`;
|
||||
}
|
||||
|
||||
fetch(`${base}/props`, { headers })
|
||||
.then((response) => {
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
window.location.reload();
|
||||
}
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error('Error checking API key:', e);
|
||||
});
|
||||
// No API key configured — server doesn't require auth, no need to validate.
|
||||
// This mirrors the early return in validateApiKey() to avoid redundant /props requests.
|
||||
if (!apiKey || apiKey.trim() === '') {
|
||||
return;
|
||||
}
|
||||
|
||||
untrack(() => {
|
||||
if (
|
||||
(page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') &&
|
||||
page.status !== 401 &&
|
||||
page.status !== 403
|
||||
) {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey.trim()}`
|
||||
};
|
||||
|
||||
fetch(`${base}/props`, { headers })
|
||||
.then((response) => {
|
||||
if (response.status === 401 || response.status === 403) {
|
||||
window.location.reload();
|
||||
}
|
||||
})
|
||||
.catch((e) => {
|
||||
console.error('Error checking API key:', e);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function handleTitleUpdateCancel() {
|
||||
|
|
@ -134,6 +143,7 @@
|
|||
$effect(() => {
|
||||
if (alwaysShowSidebarOnDesktop && isDesktop) {
|
||||
sidebarOpen = true;
|
||||
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
|
@ -170,6 +180,7 @@
|
|||
// Only fetch router models once when we have models loaded and in router mode
|
||||
if (isRouter && modelsCount > 0 && !routerModelsFetched) {
|
||||
routerModelsFetched = true;
|
||||
|
||||
untrack(() => {
|
||||
modelsStore.fetchRouterModels();
|
||||
});
|
||||
|
|
@ -218,7 +229,6 @@
|
|||
|
||||
<Tooltip.Provider delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<ModeWatcher />
|
||||
|
||||
<Toaster richColors />
|
||||
|
||||
<DialogConversationTitleUpdate
|
||||
|
|
@ -230,10 +240,10 @@
|
|||
/>
|
||||
|
||||
<Sidebar.Provider bind:open={sidebarOpen}>
|
||||
<div class="flex h-screen w-full" style:height="{innerHeight}px">
|
||||
<Sidebar.Root variant="floating" class="h-full">
|
||||
<SidebarNavigation bind:this={chatSidebar} />
|
||||
</Sidebar.Root>
|
||||
<div class="flex h-screen w-full">
|
||||
<Sidebar.Root variant="floating" class="h-full"
|
||||
><SidebarNavigation bind:this={chatSidebar} /></Sidebar.Root
|
||||
>
|
||||
|
||||
{#if !(alwaysShowSidebarOnDesktop && isDesktop) && !(panelNav.isSettingsRoute && !isDesktop)}
|
||||
{#if mounted}
|
||||
|
|
@ -261,9 +271,9 @@
|
|||
/>
|
||||
{/if}
|
||||
|
||||
<Sidebar.Inset class="flex flex-1 flex-col overflow-hidden">
|
||||
{@render children?.()}
|
||||
</Sidebar.Inset>
|
||||
<Sidebar.Inset class="flex flex-1 flex-col overflow-hidden"
|
||||
>{@render children?.()}</Sidebar.Inset
|
||||
>
|
||||
</div>
|
||||
</Sidebar.Provider>
|
||||
</Tooltip.Provider>
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ $use-ttf: false;
|
|||
$font-folder: 'katex-fonts';
|
||||
|
||||
// Import KaTeX SCSS with overridden variables
|
||||
// Note: @import is deprecated but required because KaTeX uses @import internally
|
||||
// The deprecation warnings are from KaTeX's code and cannot be avoided
|
||||
@import 'katex/src/styles/katex.scss';
|
||||
@use 'katex/src/styles/katex.scss' with (
|
||||
$use-woff2: true,
|
||||
$use-woff: false,
|
||||
$use-ttf: false,
|
||||
$font-folder: 'katex-fonts'
|
||||
);
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ import TestWrapper from './components/TestWrapper.svelte';
|
|||
|
||||
describe('/+page.svelte', () => {
|
||||
it('should render page without throwing', async () => {
|
||||
// Basic smoke test - page should render without throwing errors
|
||||
// API calls will fail in test environment but component should still mount
|
||||
expect(() => render(TestWrapper)).not.toThrow();
|
||||
// Basic smoke test - page should render without throwing errors.
|
||||
// API calls are mocked in vitest-setup-client.ts.
|
||||
await render(TestWrapper);
|
||||
expect(true).toBe(true);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -23,18 +23,6 @@ export default defineConfig({
|
|||
minify: true
|
||||
},
|
||||
|
||||
css: {
|
||||
preprocessorOptions: {
|
||||
scss: {
|
||||
additionalData: `
|
||||
$use-woff2: true;
|
||||
$use-woff: false;
|
||||
$use-ttf: false;
|
||||
`
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],
|
||||
|
||||
test: {
|
||||
|
|
|
|||
|
|
@ -1,2 +1,80 @@
|
|||
/// <reference types="@vitest/browser/matchers" />
|
||||
/// <reference types="@vitest/browser/providers/playwright" />
|
||||
|
||||
import { beforeEach, vi } from 'vitest';
|
||||
|
||||
// Mock fetch for API calls during client tests.
|
||||
// In test environment there is no backend server, so we intercept
|
||||
// the specific endpoints the app uses and return valid mock data.
|
||||
beforeEach(() => {
|
||||
const originalFetch = globalThis.fetch;
|
||||
|
||||
vi.spyOn(globalThis, 'fetch').mockImplementation(
|
||||
async (input: RequestInfo | URL, init?: RequestInit) => {
|
||||
const url = typeof input === 'string' ? input : input instanceof URL ? input.href : input.url;
|
||||
|
||||
// Mock server props endpoint
|
||||
if (url.includes('/server')) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
mode: 'router',
|
||||
version: 'test',
|
||||
git_commit: 'test',
|
||||
git_branch: 'test'
|
||||
}),
|
||||
{ status: 200, headers: { 'Content-Type': 'application/json' } }
|
||||
);
|
||||
}
|
||||
|
||||
// Mock models list endpoint
|
||||
if (/\/v1\/models|\/models\b/.test(url)) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
object: 'list',
|
||||
data: [
|
||||
{
|
||||
id: 'test-model.gguf',
|
||||
object: 'model',
|
||||
owned_by: 'llamacpp',
|
||||
created: 0,
|
||||
in_cache: false,
|
||||
path: 'models/test-model.gguf',
|
||||
status: { value: 'unloaded' },
|
||||
meta: {}
|
||||
}
|
||||
],
|
||||
models: [
|
||||
{
|
||||
model: 'test-model.gguf',
|
||||
name: 'Test Model',
|
||||
details: {}
|
||||
}
|
||||
]
|
||||
}),
|
||||
{ status: 200, headers: { 'Content-Type': 'application/json' } }
|
||||
);
|
||||
}
|
||||
|
||||
// Mock /props endpoint (used for modalities)
|
||||
if (url.includes('/props')) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
default_generation_settings: { n_ctx: 2048 }
|
||||
}),
|
||||
{ status: 200, headers: { 'Content-Type': 'application/json' } }
|
||||
);
|
||||
}
|
||||
|
||||
// Mock /tools endpoint (used for built-in tools list)
|
||||
if (url.includes('/tools')) {
|
||||
return new Response(JSON.stringify([]), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
}
|
||||
|
||||
// Default: use real fetch
|
||||
return originalFetch(input, init);
|
||||
}
|
||||
);
|
||||
});
|
||||
|
|
|
|||
1
tools/ui/vitest.shims.d.ts
vendored
Normal file
1
tools/ui/vitest.shims.d.ts
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
/// <reference types="@vitest/browser-playwright" />
|
||||
Loading…
Add table
Add a link
Reference in a new issue