Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/cann.Dockerfile
#	.devops/cpu.Dockerfile
#	.devops/cuda.Dockerfile
#	.devops/intel.Dockerfile
#	.devops/llama-cli-cann.Dockerfile
#	.devops/musa.Dockerfile
#	.devops/openvino.Dockerfile
#	.devops/rocm.Dockerfile
#	.devops/s390x.Dockerfile
#	.devops/vulkan.Dockerfile
#	.github/ISSUE_TEMPLATE/011-bug-results.yml
#	.github/ISSUE_TEMPLATE/019-bug-misc.yml
#	.github/workflows/build-and-test-snapdragon.yml
#	.github/workflows/docker.yml
#	.github/workflows/server-self-hosted.yml
#	.github/workflows/ui-ci.yml
#	.pi/gg/SYSTEM.md
#	README.md
#	common/arg.cpp
#	docs/backend/SYCL.md
#	docs/backend/snapdragon/CMakeUserPresets.json
#	docs/backend/snapdragon/README.md
#	docs/speculative.md
#	examples/save-load-state/save-load-state.cpp
#	ggml/src/ggml-hexagon/ggml-hexagon.cpp
#	ggml/src/ggml-hexagon/htp/CMakeLists.txt
#	ggml/src/ggml-hexagon/htp/htp-ctx.h
#	ggml/src/ggml-hexagon/htp/htp-ops.h
#	ggml/src/ggml-hexagon/htp/main.c
#	ggml/src/ggml-hexagon/htp/rope-ops.c
#	ggml/src/ggml-hexagon/htp/unary-ops.c
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/cvt.cl
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl
#	tools/cli/README.md
#	tools/server/README.md
This commit is contained in:
Concedo 2026-05-20 18:48:34 +08:00
commit 7d987af23a
77 changed files with 1291 additions and 1048 deletions

View file

@ -4,7 +4,6 @@
#include "chat.h"
#include "common.h"
#include "download.h"
#include "hf-cache.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "sampling.h"
@ -539,7 +538,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
const bool skip = (arg == "--spec-type");
if (!skip) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
@ -588,12 +591,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// TODO: Remove later // KCPP: remove for now
// try {
// hf_cache::migrate_old_cache_to_hf_cache(params.hf_token, params.offline);
// } catch (const std::exception & e) {
// LOG_WRN("HF cache migration failed: %s\n", e.what());
// }
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
@ -902,7 +899,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
const bool skip = (arg == "--spec-type");
if (!skip) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
}
auto opt = *arg_to_options[arg];
std::string val;
@ -3365,7 +3366,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
" - 1: error\n"
" - 2: warning\n"
" - 3: info\n"
" - 4: debug\n"
" - 4: trace (more info)\n"
" - 5: debug\n"
"(default: %d)\n", params.verbosity),
[](common_params & params, int value) {
params.verbosity = value;
@ -4126,6 +4128,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_mod.n_match = 24;
params.speculative.ngram_mod.n_min = 48;
params.speculative.ngram_mod.n_max = 64;
// TODO: not sure if this is a good config - explore more settings and potentially enable it
//params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
//params.speculative.ngram_map_k4v.size_n = 8;
//params.speculative.ngram_map_k4v.size_m = 24;
//params.speculative.ngram_map_k4v.min_hits = 2;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));

View file

@ -1166,7 +1166,7 @@ struct common_init_result::impl {
std::vector<llama_sampler_seq_config> samplers_seq_config;
};
common_init_result::common_init_result(common_params & params) :
common_init_result::common_init_result(common_params & params, bool model_only) :
pimpl(new impl{}) {
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
@ -1179,7 +1179,7 @@ common_init_result::common_init_result(common_params & params) :
params.tensor_buft_overrides.data(),
params.fit_params_target.data(),
params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@ -1189,6 +1189,10 @@ common_init_result::common_init_result(common_params & params) :
pimpl->model.reset(model);
if (model_only) {
return;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// load and optionally apply lora adapters
@ -1258,29 +1262,6 @@ common_init_result::common_init_result(common_params & params) :
cparams.n_samplers = pimpl->samplers_seq_config.size();
}
// [TAG_RS_STATE_ROLLBACK_SUPPORT]
// TODO: ngram speculative methods require checkpointing in addition to partial RS rollback
// currently this is not supported. so we disable the partial rollback
if (cparams.n_rs_seq > 0 && (llama_model_is_recurrent(model) || llama_model_is_hybrid(model))) {
auto & types = params.speculative.types;
for (int i = 0; i < (int) types.size(); i++) {
if (types[i] == COMMON_SPECULATIVE_TYPE_NONE) {
continue;
}
if (types[i] == COMMON_SPECULATIVE_TYPE_DRAFT_MTP) {
continue;
}
cparams.n_rs_seq = 0;
LOG_WRN("%s: recurrent state rollback is not compatible with '%s' - disabling rollback support\n", __func__,
common_speculative_type_to_str(types[i]).c_str());
break;
}
}
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
@ -1315,8 +1296,8 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
common_init_result_ptr common_init_from_params(common_params & params, bool model_only) {
common_init_result_ptr res(new common_init_result(params, model_only));
llama_model * model = res->model();
if (model == NULL) {
@ -1324,6 +1305,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
return res;
}
if (model_only) {
return res;
}
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
@ -1387,7 +1372,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
}
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
llama_set_warmup(lctx, true);

View file

@ -300,11 +300,11 @@ struct common_params_model {
// draft-model-based speculative decoding parameters
struct common_params_speculative_draft {
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_max = 3; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy) // TODO: change default to 0.0f
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.0f; // minimum speculative decoding probability (greedy)
common_params_model mparams;
@ -858,7 +858,7 @@ struct common_sampler;
// note: defines the model, context, samplers, ets. lifetimes
struct common_init_result {
common_init_result(common_params & params);
common_init_result(common_params & params, bool model_only = false);
~common_init_result();
llama_model * model();
@ -876,7 +876,7 @@ private:
using common_init_result_ptr = std::unique_ptr<common_init_result>;
common_init_result_ptr common_init_from_params(common_params & params);
common_init_result_ptr common_init_from_params(common_params & params, bool model_only = false);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);

View file

@ -11,7 +11,6 @@
#include <filesystem>
#include <fstream>
#include <atomic>
#include <regex> // migration only
#include <string>
#include <string_view>
#include <stdexcept>
@ -336,15 +335,9 @@ hf_files get_repo_files(const std::string & repo_id,
if (item["lfs"].contains("oid") && item["lfs"]["oid"].is_string()) {
file.oid = item["lfs"]["oid"].get<std::string>();
}
if (item["lfs"].contains("size") && item["lfs"]["size"].is_number()) {
file.size = item["lfs"]["size"].get<size_t>();
}
} else if (item.contains("oid") && item["oid"].is_string()) {
file.oid = item["oid"].get<std::string>();
}
if (file.size == 0 && item.contains("size") && item["size"].is_number()) {
file.size = item["size"].get<size_t>();
}
if (!file.oid.empty() && !is_valid_oid(file.oid)) {
LOG_WRN("%s: skip invalid oid: %s\n", __func__, file.oid.c_str());
@ -502,271 +495,4 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}
// delete everything after this line, one day
// copied from download.cpp without the tag part
struct gguf_split_info {
std::string prefix; // tag included
int index;
int count;
};
static gguf_split_info get_gguf_split_info(const std::string & path) {
static const std::regex re_split("^(.+)-([0-9]{5})-of-([0-9]{5})$", std::regex::icase);
std::smatch m;
std::string prefix = path;
if (!string_remove_suffix(prefix, ".gguf")) {
return {};
}
int index = 1;
int count = 1;
if (std::regex_match(prefix, m, re_split)) {
index = std::stoi(m[2].str());
count = std::stoi(m[3].str());
prefix = m[1].str();
}
return {std::move(prefix), index, count};
}
static std::pair<std::string, std::string> parse_manifest_name(std::string & filename) {
static const std::regex re(R"(^manifest=([^=]+)=([^=]+)=.*\.json$)");
std::smatch match;
if (std::regex_match(filename, match, re)) {
return {match[1].str(), match[2].str()};
}
return {};
}
static std::string make_old_cache_filename(const std::string & owner,
const std::string & repo,
const std::string & filename) {
auto result = owner + "_" + repo + "_" + filename;
string_replace_all(result, "/", "_");
return result;
}
struct migrate_file {
std::string path;
std::string sha256;
size_t size;
fs::path old_path;
fs::path etag_path;
const hf_file * file;
};
using migrate_files = std::vector<migrate_file>;
static bool collect_file(const fs::path & old_cache,
const std::string & owner,
const std::string & repo,
const std::string & path,
const std::string & sha256,
const hf_files & files,
migrate_files & to_migrate) {
const hf_file * file = nullptr;
for (const auto & f : files) {
if (f.path == path) {
file = &f;
break;
}
}
std::string old_filename = make_old_cache_filename(owner, repo, path);
fs::path old_path = old_cache / old_filename;
fs::path etag_path = old_path.string() + ".etag";
if (!fs::exists(old_path)) {
if (file && fs::exists(file->final_path)) {
return true;
}
LOG_WRN("%s: %s not found in old cache or HF cache\n", __func__, old_filename.c_str());
return false;
}
if (!file) {
LOG_WRN("%s: %s not found in current repo\n", __func__, old_filename.c_str());
return false;
}
if (!sha256.empty() && !file->oid.empty() && sha256 != file->oid) {
LOG_WRN("%s: %s is not up to date (sha256 mismatch)\n", __func__, old_filename.c_str());
return false;
}
if (file->size > 0) {
size_t size = fs::file_size(old_path);
if (size != file->size) {
LOG_WRN("%s: %s has wrong size %zu (expected %zu)\n", __func__, old_filename.c_str(), size, file->size);
return false;
}
}
to_migrate.push_back({path, sha256, file->size, old_path, etag_path, file});
return true;
}
static bool collect_files(const fs::path & old_cache,
const std::string & owner,
const std::string & repo,
const nl::json & node,
const hf_files & files,
migrate_files & to_migrate) {
if (!node.contains("rfilename") ||
!node.contains("lfs") ||
!node["lfs"].contains("sha256")) {
return true;
}
std::string path = node["rfilename"];
std::string sha256 = node["lfs"]["sha256"];
auto split = get_gguf_split_info(path);
if (split.count <= 1) {
return collect_file(old_cache, owner, repo, path, sha256, files, to_migrate);
}
std::vector<std::pair<std::string, std::string>> splits;
for (const auto & f : files) {
auto split_f = get_gguf_split_info(f.path);
if (split_f.count == split.count && split_f.prefix == split.prefix) {
// sadly the manifest only provides the sha256 of the first file (index == 1)
// the rest will be verified using the size...
std::string f_sha256 = (split_f.index == 1) ? sha256 : "";
splits.emplace_back(f.path, f_sha256);
}
}
if ((int)splits.size() != split.count) {
LOG_WRN("%s: expected %d split files but found %d in repo\n", __func__, split.count, (int)splits.size());
return false;
}
for (const auto & [f_path, f_sha256] : splits) {
if (!collect_file(old_cache, owner, repo, f_path, f_sha256, files, to_migrate)) {
return false;
}
}
return true;
}
static bool migrate_file(const migrate_file & file) {
std::error_code ec;
fs::path new_path(file.file->local_path);
fs::create_directories(new_path.parent_path(), ec);
if (!fs::exists(new_path, ec)) {
fs::rename(file.old_path, new_path, ec);
if (ec) {
fs::copy_file(file.old_path, new_path, ec);
if (ec) {
LOG_ERR("%s: failed to move/copy %s: %s\n", __func__, file.old_path.string().c_str(), ec.message().c_str());
return false;
}
}
fs::remove(file.old_path, ec);
}
fs::remove(file.etag_path, ec);
std::string filename = finalize_file(*file.file);
LOG_INF("%s: migrated %s -> %s\n", __func__, file.old_path.filename().string().c_str(), filename.c_str());
return true;
}
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline) {
fs::path old_cache = fs_get_cache_directory();
if (!fs::exists(old_cache)) {
return;
}
if (offline) {
LOG_WRN("%s: skipping migration in offline mode (will run when online)\n", __func__);
return; // -hf is not going to work
}
bool warned = false;
for (const auto & entry : fs::directory_iterator(old_cache)) {
if (!entry.is_regular_file()) {
continue;
}
auto filename = entry.path().filename().string();
auto [owner, repo] = parse_manifest_name(filename);
if (owner.empty() || repo.empty()) {
continue;
}
if (!warned) {
warned = true;
LOG_WRN("================================================================================\n"
"WARNING: Migrating cache to HuggingFace cache directory\n"
" Old cache: %s\n"
" New cache: %s\n"
"This one-time migration moves models previously downloaded with -hf\n"
"from the legacy llama.cpp cache to the standard HuggingFace cache.\n"
"Models downloaded with --model-url are not affected.\n"
"================================================================================\n",
old_cache.string().c_str(), get_cache_directory().string().c_str());
}
auto repo_id = owner + "/" + repo;
auto files = get_repo_files(repo_id, token);
if (files.empty()) {
LOG_WRN("%s: could not get repo files for %s, skipping\n", __func__, repo_id.c_str());
continue;
}
migrate_files to_migrate;
bool ok = true;
try {
std::ifstream manifest(entry.path());
auto json = nl::json::parse(manifest);
for (const char * key : {"ggufFile", "mmprojFile"}) {
if (json.contains(key)) {
if (!collect_files(old_cache, owner, repo, json[key], files, to_migrate)) {
ok = false;
break;
}
}
}
} catch (const std::exception & e) {
LOG_WRN("%s: failed to parse manifest %s: %s\n", __func__, filename.c_str(), e.what());
continue;
}
if (!ok) {
LOG_WRN("%s: migration skipped: one or more files failed validation\n", __func__);
continue;
}
for (const auto & file : to_migrate) {
if (!migrate_file(file)) {
ok = false;
break;
}
}
if (!ok) {
LOG_WRN("%s: migration failed: could not migrate all files\n", __func__);
continue;
}
LOG_INF("%s: migration complete, deleting manifest: %s\n", __func__, entry.path().string().c_str());
fs::remove(entry.path());
}
}
} // namespace hf_cache

View file

@ -14,7 +14,6 @@ struct hf_file {
std::string final_path;
std::string oid;
std::string repo_id;
size_t size = 0; // only for the migration
};
using hf_files = std::vector<hf_file>;
@ -30,7 +29,4 @@ hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);
} // namespace hf_cache

View file

@ -500,7 +500,7 @@ void common_ngram_map_draft(common_ngram_map & map,
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
LOG_DBG("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
key_offset, slot_max,
curr_key.key_num, draft.size());

View file

@ -32,6 +32,19 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};
static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
if (devices.empty()) {
return "default";
}
std::string result;
for (size_t i = 0; i < devices.size(); i++) {
if (i > 0) result += ", ";
result += ggml_backend_dev_name(devices[i]);
}
return result;
}
struct common_speculative_config {
common_speculative_type type;
common_params_speculative params;
@ -144,7 +157,7 @@ struct common_speculative_impl {
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;
@ -167,6 +180,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
ctx_tgt ? "yes" : "no",
ctx_dft ? "yes" : "no",
common_speculative_get_devices_str(this->params.devices).c_str());
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
// TODO: optimize or pass from outside?
@ -343,7 +366,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
@ -355,8 +378,12 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
//common_params_speculative_eagle3 params;
common_speculative_impl_draft_eagle3(const common_params_speculative & /*params*/, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) {}
common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
{
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min);
}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
@ -371,7 +398,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
// TODO: implement
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
@ -380,7 +407,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
}
};
struct common_speculative_state_draft_mtp : public common_speculative_impl {
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
llama_batch batch;
@ -407,7 +434,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
common_speculative_state_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
, params(params.draft)
{
@ -417,6 +444,16 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
ctx_tgt ? "yes" : "no",
ctx_dft ? "yes" : "no",
common_speculative_get_devices_str(this->params.devices).c_str());
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; MTP needs both.
@ -427,7 +464,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1; // TODO: re-enable top_k == 10 and utilize `p_min` spec param
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
@ -446,7 +483,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_state_draft_mtp() override {
~common_speculative_impl_draft_mtp() override {
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
@ -462,7 +499,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d "
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
@ -633,6 +670,14 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
@ -678,7 +723,7 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
@ -714,7 +759,12 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
common_ngram_simple_config config)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
, params(params.ngram_simple)
, config(config) {}
, config(config)
{
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
this->params.size_n, this->params.size_m, this->params.min_hits);
}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
@ -738,7 +788,7 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
@ -748,20 +798,21 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
};
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
common_params_speculative_ngram_map params;
// n_seq configs
std::vector<common_ngram_map> config;
common_speculative_impl_ngram_map_k(
const common_params_speculative & params,
const common_ngram_map & config,
uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
, params(params.ngram_map_k) {
{
for (uint32_t i = 0; i < n_seq; i++) {
this->config.push_back(config);
}
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
config.size_key, config.size_value, config.key_only, config.min_hits);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
@ -788,9 +839,13 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
GGML_ASSERT((seq_id < (llama_seq_id) config.size()));
if (is_other) {
return;
}
common_ngram_map_accept(config[seq_id], n_accepted);
}
@ -812,7 +867,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
// the last position in the prompt that was added to the ngram container
size_t i_last = 0;
// length of the last drafted ngram (number of tokens returned by draft)
// length of the last drafted n-gram (number of tokens returned by draft)
size_t n_draft_last = 0;
// consecutive accept rounds with low acceptance fraction (< 0.5)
@ -830,8 +885,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
LOG_INF("%s: initialized ngram_mod with n_match=%d, size=%zu (%.3f MB)\n", __func__,
this->params.n_match, mod.size(), (float)(mod.size_bytes())/1024/1024);
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
this->params.n_match, this->params.n_max, this->params.n_min);
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
mod.size(), (float)(mod.size_bytes())/1024/1024);
if (this->params.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
@ -921,7 +979,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
}
result.resize(result.size() - n);
// store length of drafted ngram for later acceptance analysis
// store length of drafted n-gram for later acceptance analysis
sinfo.n_draft_last = result.size();
}
@ -943,17 +1001,21 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
if (is_other) {
return;
}
auto & sinfo = sinfos[seq_id];
// compute acceptance fraction if we have a recorded draft length
if (sinfo.n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last;
if (f_acc < 0.5) {
if (f_acc < 0.25) {
sinfo.n_low++;
if (sinfo.n_low >= 3) {
if (sinfo.n_low >= 5) {
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) resetting ngram_mod\n", __func__, sinfo.n_low);
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
}
mod.reset();
@ -1003,6 +1065,12 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
, save_dynamic(save_dynamic)
, save_static(save_static)
{
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
n_draft,
path_static.empty() ? "none" : path_static.c_str(),
path_dynamic.empty() ? "none" : path_dynamic.c_str());
sinfos.resize(n_seq);
if (!path_static.empty()) {
@ -1099,7 +1167,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/) override {
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
@ -1285,7 +1353,6 @@ common_speculative * common_speculative_init(common_params_speculative & params,
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
for (const common_speculative_config & config : configs) {
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(config.type).c_str());
switch (config.type) {
case COMMON_SPECULATIVE_TYPE_NONE:
break;
@ -1298,7 +1365,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
impls.push_back(std::make_unique<common_speculative_state_draft_mtp>(config.params, n_seq));
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
@ -1319,11 +1386,16 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::move(state));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
impls.push_back(
std::make_unique<common_speculative_impl_ngram_map_k>(
get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(
std::make_unique<common_speculative_impl_ngram_map_k>(
config.params, get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
@ -1515,11 +1587,6 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
GGML_ASSERT(impl);
// TODO: currently only the implementation that generated the draft is used to accept it
// however, some implementations (such as MTP) need to also "see" the accepted tokens
// extend `common_speculative_impl::accept()` with an extra argument `bool is_other` to
// inform the implementation if the accepted tokens are from another implementation and
// pass the accepted tokens to all remaining implementations using `is_other == true`
{
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
if (n_accepted > 0) {
@ -1527,9 +1594,16 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
impl->n_acc_tokens += n_accepted;
}
impl->accept(seq_id, n_accepted);
impl->accept(seq_id, n_accepted, false);
impl->n_call_accept++;
}
// accept with the rest of the implementations, using is_other == true
for (auto & impl_other : spec->impls) {
if (impl_other.get() != impl) {
impl_other->accept(seq_id, n_accepted, true);
}
}
}
void common_speculative_print_stats(const common_speculative * spec) {
@ -1549,7 +1623,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf = "";
}
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,

View file

@ -115,15 +115,15 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
help="Export multimodal projector (mmproj) for vision models. This will only work on some vision models. An 'mmproj-' prefix will be added to the output file name.",
)
parser.add_argument(
"--mtp", action="store_true",
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
help="Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. An 'mtp-' prefix will be added to the output file name.",
)
parser.add_argument(
"--no-mtp", action="store_true",
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
help="Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, but even though the bundled default is more space-efficient overall, this allows differing quantization which may be more performant.",
)
parser.add_argument(
"--mistral-format", action="store_true",

View file

@ -445,6 +445,11 @@ if __name__ == '__main__':
if self.lazy:
tensor = LazyTorchTensor.from_eager(tensor)
base_name = get_base_tensor_name(name)
# filter base name, ignore tensor transformations for now
data_gen = lambda g=tensor: g # noqa: E731
if (titem := self.filter_tensors((base_name, data_gen))) is None:
continue
base_name, _ = titem
# note: mergekit-extract-lora also adds token embeddings to the adapter
is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name

View file

@ -149,6 +149,8 @@ class TaskState:
t_gen_ms: Optional[float] = None
reasoning_content: Optional[str] = None
server_name: Optional[str] = None
chunk_idx: int = 0
problem_idx: int = 0
class EvalState:
@ -233,7 +235,9 @@ class EvalState:
tps_gen: Optional[float] = None,
t_gen_ms: Optional[float] = None,
reasoning_content: Optional[str] = None,
server_name: Optional[str] = None
server_name: Optional[str] = None,
chunk_idx: int = 0,
problem_idx: int = 0,
):
with self._lock:
if "cases" not in self.task_states:
@ -252,7 +256,9 @@ class EvalState:
"tps_gen": tps_gen,
"t_gen_ms": t_gen_ms,
"reasoning_content": reasoning_content,
"server_name": server_name
"server_name": server_name,
"chunk_idx": chunk_idx,
"problem_idx": problem_idx,
}
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
@ -289,6 +295,9 @@ class EvalState:
all_cases = {}
for i, task_id in tasks_to_save:
question_text, prompt, expected = self.get_case(i)
# Extract chunk_idx from task_id for pending cases
_parts = task_id.rsplit("_", 2)
_chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
@ -306,7 +315,9 @@ class EvalState:
"tps_gen": None,
"t_gen_ms": None,
"reasoning_content": None,
"server_name": None
"server_name": None,
"chunk_idx": _chunk_idx,
"problem_idx": i,
}
ci_lower, ci_upper = self.accuracy_ci()
@ -382,11 +393,12 @@ class EvalState:
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
escaped_server = self._escape_html(server_name)
answer_class = status_class if status == "ok" else ""
rows.append(f"""<tr class="task-row" onclick="toggleDetails('{task_id}')">
<td>{task_id}</td>
<td class="{status_class}">{status_text}</td>
<td>{self._escape_html(expected)}</td>
<td>{self._escape_html(answer)}</td>
<td class="{answer_class}">{self._escape_html(answer)}</td>
<td>{tokens_str}</td>
<td>{tps_str}</td>
<td>{t_gen_str}</td>
@ -405,6 +417,53 @@ class EvalState:
rows_html = "\n".join(rows)
# ---- per-problem summary table ----
problem_groups: Dict[int, List[Dict[str, Any]]] = {}
for _tid, _case in cases.items():
if _case.get("status") != "ok":
continue
_pidx = _case.get("problem_idx")
if _pidx is None:
_p_parts = _tid.rsplit("_", 2)
_pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0
problem_groups.setdefault(_pidx, []).append(_case)
summary_rows_html = ""
if problem_groups:
def _stat(v, fmt=".1f", avg_fmt=None):
if not v:
return ("", "", "")
af = fmt if avg_fmt is None else avg_fmt
return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}")
summary_data = []
for pidx, g in problem_groups.items():
runs = len(g)
n_ok = sum(1 for c in g if c.get("correct", False))
toks = [c["tokens"] for c in g if c.get("tokens") is not None]
tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None]
tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None]
summary_data.append((
pidx, runs, n_ok,
_stat(toks, "d", ".0f"),
_stat(tps),
_stat(tg),
))
summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending
summary_rows_html = "\n".join(
f"""<tr class="summary-row">
<td>{p:03d}</td>
<td>{r}</td>
<td>{n}/{r}</td>
<td>{tk[0]}</td><td>{tk[1]}</td><td>{tk[2]}</td>
<td>{tp[0]}</td><td>{tp[1]}</td><td>{tp[2]}</td>
<td>{tg[0]}</td><td>{tg[1]}</td><td>{tg[2]}</td>
</tr>"""
for p, r, n, tk, tp, tg in summary_data
)
html_content = f"""<!DOCTYPE html>
<html>
<head>
@ -412,10 +471,10 @@ class EvalState:
<title>{self.dataset_type.upper()} Eval</title>
<style>
body {{ font-family: system-ui, sans-serif; margin: 0; padding: 16px; background: #fff; color: #222; }}
.bar {{ padding: 8px 0; font-size: 14px; color: #555; }}
.bar span {{ margin-right: 20px; }}
.bar b {{ color: #222; }}
table {{ width: 100%; border-collapse: collapse; font-size: 13px; }}
.bar {{ padding: 8px 0; font-size: 13px; color: #555; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; display: grid; grid-template-columns: auto 1fr auto 1fr; gap: 2px 12px; align-items: baseline; }}
.bar .label {{ color: #888; }}
.bar .value {{ color: #222; }}
table {{ width: 100%; border-collapse: collapse; font-size: 13px; font-family: 'SF Mono', 'Menlo', 'Consolas', monospace; }}
th {{ text-align: left; padding: 6px 8px; border-bottom: 2px solid #ccc; font-weight: 600; }}
td {{ padding: 4px 8px; border-bottom: 1px solid #eee; vertical-align: top; }}
.task-row {{ cursor: pointer; }}
@ -429,37 +488,88 @@ class EvalState:
.details-content {{ padding: 8px 16px; background: #f6f8fa; font-size: 12px; }}
.details-content b {{ color: #555; }}
.details-content pre {{ background: #fff; border: 1px solid #e1e4e8; padding: 8px; overflow-x: auto; white-space: pre-wrap; word-wrap: break-word; margin: 4px 0 8px; }}
.summary-table {{ margin-bottom: 16px; font-size: 13px; width: 100%; }}
.summary-row {{ background: #fafbfc; }}
.summary-row:hover {{ background: #f5f5f5; }}
.summary-table th {{ text-align: right; font-weight: 600; }}
.summary-table th:first-child {{ text-align: left; }}
.summary-table th[colspan] {{ text-align: center; }}
.summary-table td {{ text-align: right; }}
.summary-table td:first-child {{ text-align: left; }}
.tabs {{ display: flex; border-bottom: 2px solid #ddd; margin: 12px 0 0; }}
.tab-btn {{ padding: 6px 16px; border: none; background: none; font-size: 13px; cursor: pointer; color: #555; border-bottom: 2px solid transparent; margin-bottom: -2px; font-weight: 500; }}
.tab-btn:hover {{ color: #222; }}
.tab-btn.active {{ color: #222; border-bottom-color: #222; font-weight: 600; }}
.tab-content {{ display: none; }}
.tab-content.active {{ display: block; }}
</style>
</head>
<body>
<div class="bar">
<span><b>{self.dataset_type.upper()}</b></span>
<span>Model: {self.model_name or 'N/A'}</span>
<span>Accuracy: <b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</span>
<span>Correct: <span class="correct">{n_correct}</span> / {len(completed)}</span>
<span>Pending: {n_pending}</span>
<span>Time: {self.total_time:.1f}s</span>
<span>Sampling: {sampling_str}</span>
<div class="label">Dataset</div><div class="value"><b>{self.dataset_type.upper()}</b></div>
<div class="label">Model</div><div class="value"><b>{self.model_name or 'N/A'}</b></div>
<div class="label">Accuracy</div><div class="value"><b>{accuracy:.1f}%</b> [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]</div>
<div class="label">Correct</div><div class="value"><span class="correct">{n_correct}</span> / {len(completed)}</div>
<div class="label">Pending</div><div class="value">{n_pending}</div>
<div class="label">Time</div><div class="value">{self.total_time:.1f}s</div>
<div class="label">Sampling</div><div class="value">{sampling_str}</div>
</div>
<div class="tabs">
<button class="tab-btn active" data-tab="detailed" onclick="switchTab(this)">Detailed</button>
<button class="tab-btn" data-tab="summary" onclick="switchTab(this)">Summary</button>
</div>
<div id="tab-detailed" class="tab-content active">
<table>
<thead>
<tr>
<th>ID</th>
<th></th>
<th>Gold</th>
<th>Answer</th>
<th>Tokens</th>
<th>T/s</th>
<th>Gen s</th>
<th>Server</th>
</tr>
</thead>
<tbody>
{rows_html}
</tbody>
</table>
</div>
<div id="tab-summary" class="tab-content">
<table class="summary-table">
<thead>
<tr>
<th>Problem</th>
<th>Runs</th>
<th>Correct</th>
<th colspan="3">Tokens</th>
<th colspan="3">T/s</th>
<th colspan="3">Gen s</th>
</tr>
<tr>
<th></th>
<th></th>
<th></th>
<th>min</th><th>avg</th><th>max</th>
<th>min</th><th>avg</th><th>max</th>
<th>min</th><th>avg</th><th>max</th>
</tr>
</thead>
<tbody>
{summary_rows_html}
</tbody>
</table>
</div>
<table>
<thead>
<tr>
<th>ID</th>
<th></th>
<th>Gold</th>
<th>Answer</th>
<th>Tokens</th>
<th>T/s</th>
<th>Gen s</th>
<th>Server</th>
</tr>
</thead>
<tbody>
{rows_html}
</tbody>
</table>
<script>
function toggleDetails(id) {{ document.getElementById('details-'+id).classList.toggle('open'); }}
function switchTab(btn) {{
document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
btn.classList.add('active');
document.getElementById('tab-'+btn.dataset.tab).classList.add('active');
}}
</script>
</body>
</html>"""
@ -1062,12 +1172,19 @@ class Processor:
) -> TaskState:
question_text, prompt, expected = eval_state.get_case(i)
# Extract chunk_idx from task_id: "{dataset_type}_{chunk_idx:03d}_{index:03d}"
_parts = task_id.rsplit("_", 2)
chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
problem_idx = i
task_state = TaskState(
task_id=task_id,
prompt=prompt,
expected=expected,
question_text=question_text,
server_name=server_config.name
server_name=server_config.name,
chunk_idx=chunk_idx,
problem_idx=problem_idx,
)
try:
@ -1085,7 +1202,8 @@ class Processor:
eval_state.add_result(
task_id, prompt, expected, result, None,
{"finish_reason": finish_reason}, False, task_state.status,
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
chunk_idx, problem_idx,
)
eval_state.dump()
return task_state
@ -1108,7 +1226,8 @@ class Processor:
eval_state.add_result(
task_id, prompt, expected, result, answer,
grader_log, is_correct, "ok",
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name
tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name,
chunk_idx, problem_idx,
)
eval_state.dump()

View file

@ -65,34 +65,70 @@ def normalize_number(s: str) -> Optional[int]:
return int(match.group(0))
class AimeDataset:
def __init__(self, split: str = "train"):
def __init__(self, split: str = "train", dataset_type: str = "aime"):
self.split = split
self.dataset_type = dataset_type
self.questions: List[Dict] = []
self._load_dataset()
def _load_dataset(self):
print(f"Loading AIME dataset (split: {self.split})...")
def _get_question_text(self, question: Dict) -> str:
"""Get question text, handling different dataset field names."""
return question.get("problem", question.get("question", ""))
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
def _load_dataset(self):
if self.dataset_type == "aime":
print(f"Loading AIME dataset (split: {self.split})...")
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
else:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
elif self.dataset_type == "aime2025":
print(f"Loading AIME2025 dataset...")
ds_list = []
for config_name in ["AIME2025-I", "AIME2025-II"]:
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
else:
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test")
ds_list.extend(ds)
ds = ds_list
else:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
self.questions = list(ds)
print(f"AIME dataset loaded: {len(self.questions)} questions")
print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")
def find_question(self, request_text: str) -> Optional[Dict]:
# Strip common template prefixes to get the actual question text
# Templates include things like "Solve the following math problem step by step..."
# The actual question usually follows a blank line or after the template instruction
cleaned = request_text
# Split on double newline and take the part that looks like the problem
parts = cleaned.split('\n\n')
if len(parts) > 1:
# Find the part that's longest (likely the actual problem text)
problem_parts = [p for p in parts if len(p.strip()) > 100]
if problem_parts:
cleaned = max(problem_parts, key=lambda x: len(x))
best_match = None
best_distance = -1
best_index = -1
for i, question in enumerate(self.questions):
question_text = question["problem"]
request_lower = request_text.lower()
question_text = self._get_question_text(question)
request_lower = cleaned.lower()
question_lower = question_text.lower()
# Check if question text is contained in the cleaned request
if question_lower in request_lower or request_lower in question_lower:
debug_log(f"DEBUG: Found substring match at index {i}")
return question
# Exact match
if question_lower == request_lower:
debug_log(f"DEBUG: Found exact match at index {i}")
@ -118,7 +154,7 @@ class AimeDataset:
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
return best_match
debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
return None
def get_answer(self, question: Dict) -> str:
@ -134,15 +170,16 @@ class Simulator:
port: int = 8033,
host: str = "localhost",
success_rate: float = 0.8,
dataset_split: str = "train"
dataset_split: str = "train",
dataset_type: str = "aime"
):
self.port = port
self.host = host
self.success_rate = success_rate
self.dataset = AimeDataset(dataset_split)
self.dataset = AimeDataset(dataset_split, dataset_type)
self.eval_state = EvalState(
id="aime-2025",
tasks=["aime"],
id=dataset_type,
tasks=[dataset_type],
task_states={},
sampling_config={"temperature": 0, "max_tokens": 2048}
)
@ -159,6 +196,10 @@ class Simulator:
else:
response_text = self._generate_wrong_answer(question)
comp_tokens = random.randint(10000, 60000)
tps_gen = random.uniform(90.0, 110.0)
t_gen_ms = comp_tokens / tps_gen * 1000
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
@ -176,8 +217,12 @@ class Simulator:
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150
"completion_tokens": comp_tokens,
"total_tokens": 100 + comp_tokens
},
"timings": {
"predicted_ms": t_gen_ms,
"predicted_per_second": tps_gen
}
}
@ -218,6 +263,12 @@ class Simulator:
return response
class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/v1/models":
self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
return
self._send_json({"error": "Not found"}, 404)
def do_POST(self):
if self.path != "/v1/chat/completions":
self._send_json({"error": "Not found"}, 404)
@ -280,6 +331,13 @@ def main():
default=0.8,
help="Success rate 0-1 (default: 0.8)"
)
parser.add_argument(
"--dataset",
type=str,
default="aime",
choices=["aime", "aime2025"],
help="Dataset type (default: aime)"
)
parser.add_argument(
"--dataset-split",
type=str,
@ -294,7 +352,8 @@ def main():
port=args.port,
host=args.host,
success_rate=args.success_rate,
dataset_split=args.dataset_split
dataset_split=args.dataset_split,
dataset_type=args.dataset
)
server = HTTPServer((args.host, args.port), RequestHandler)
@ -304,7 +363,7 @@ def main():
print("\n=== llama-server-simulator ===")
print(f"Server running on http://{args.host}:{args.port}")
print(f"Success rate: {args.success_rate}")
print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
print("\nPress Ctrl+C to stop\n")
try:

View file

@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
return 8;
case GGML_TYPE_Q6_K:
return 2;
case GGML_TYPE_IQ4_NL:
return 8;
default:

View file

@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
char base[256];
char name[256];
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
// note: this is slower
//const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0;
const bool is_c4 = false;
snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
res.c4 = is_c4;
return res;
}

View file

@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne00, nth_max);
const int nk0 = (args.ne00 + nth - 1)/nth;
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
nk0 = ne00/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0*ne01, 256);
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (nrptg*nth > 256) {
nrptg--;
}
}
@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
const int nth = std::min(1024, ne0);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne0, nth_max);
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
return 1;
}

View file

@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl(
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1u) {
if (K > 1) {
const int target_slot = (int)t - shift;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl(
}
}
if (K == 1u) {
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32(
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32(
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32(
}
}
kernel void kernel_pad_f32(
template <typename T>
kernel void kernel_pad_impl(
constant ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.ne00) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}
return;
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = 0.0f;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_func(src, dst_data[i00]);
@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;

View file

@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() {
return &guid;
}
struct ggml_backend_rpc_device_context {
std::string endpoint;
uint32_t device;
std::string name;
std::string description;
uint64_t last_graph_uid;
};
struct ggml_backend_rpc_buffer_type_context {
std::string endpoint;
uint32_t device;
@ -211,7 +219,6 @@ struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
uint64_t last_graph_uid;
};
struct ggml_backend_rpc_buffer_context {
@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend);
ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context;
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid;
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
} else {
rpc_ctx->last_graph_uid = cgraph->uid;
rpc_dev_ctx->last_graph_uid = cgraph->uid;
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name,
/* .last_graph_uid = */ 0,
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
}
// device interface
struct ggml_backend_rpc_device_context {
std::string endpoint;
uint32_t device;
std::string name;
std::string description;
};
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
std::string dev_name = "RPC" + std::to_string(dev_id);
std::string dev_desc = std::string(endpoint);
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
/* .endpoint = */ endpoint,
/* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc
/* .endpoint = */ endpoint,
/* .device = */ ind,
/* .name = */ dev_name,
/* .description = */ dev_desc,
/* .last_graph_uid = */ 0,
};
ggml_backend_dev_t dev = new ggml_backend_device {

View file

@ -581,7 +581,8 @@ struct llm_graph_params {
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
(
(!ubatch.token && !other.ubatch.token) ||
(!ubatch.embd && !other.ubatch.embd)
(!ubatch.embd && !other.ubatch.embd) ||
(ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd)
);
// when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same

View file

@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
if (mem_recr->n_rs_seq > 0) {
// [TAG_RECURRENT_ROLLBACK_SPLITS]
// TODO: recurrent state rollback does not support equal splits
ubatch = balloc.split_seq(n_ubatch);
} else {
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
}
if (ubatch.n_tokens == 0) {

View file

@ -75,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
if (mem_recr->n_rs_seq > 0) {
// [TAG_RECURRENT_ROLLBACK_SPLITS]
// TODO: recurrent state rollback does not support equal splits
ubatch = balloc.split_seq(n_ubatch);
} else {
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
}
if (ubatch.n_tokens == 0) {

View file

@ -416,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
if (n_rs_seq > 0) {
// [TAG_RECURRENT_ROLLBACK_SPLITS]
// TODO: recurrent state rollback does not support equal splits
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}
}
if (ubatch.n_tokens == 0) {

View file

@ -72,6 +72,7 @@ public:
// number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
uint32_t n_rs_seq = 0;
// per-seq rollback index
std::vector<uint32_t> rs_idx;

View file

@ -447,13 +447,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
return build_delta_net_chunking(q, k, v, g, b, s, il);
}
bool llm_build_delta_net_base::keep_rs() const {
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
return cparams.n_rs_seq > 0
&& n_seq_tokens > 1
&& (uint32_t) n_seq_tokens <= 1 + cparams.n_rs_seq;
}
ggml_tensor * llm_build_delta_net_base::build_conv_state(
llm_graph_input_rs * inp,
ggml_tensor * conv_states_all,
@ -461,12 +454,12 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
int64_t conv_kernel_size,
int64_t conv_channels,
int il) {
const auto * mctx_cur = inp->mctx;
const auto kv_head = mctx_cur->get_head();
const uint32_t mem_size = mctx_cur->get_size();
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const bool keep = keep_rs();
const auto * mctx_cur = inp->mctx;
const auto kv_head = mctx_cur->get_head();
const auto mem_size = mctx_cur->get_size();
const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
@ -480,32 +473,52 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
if (!keep) {
ggml_tensor * last_conv_states =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
cb(last_conv_states, "last_conv_states", il);
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
ggml_tensor * state_update_target =
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
const size_t row_size = ggml_row_size(conv_states_all->type, row_count);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
if (cparams.n_rs_seq == 0) {
const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0];
const int64_t s_slot = 0;
ggml_tensor * conv_state_last =
ggml_view_3d(ctx0, conv_input,
conv_kernel_size - 1, conv_channels, n_seqs,
conv_input->nb[1], conv_input->nb[2],
ggml_row_size(conv_input->type, s_idx));
cb(conv_state_last, "conv_state_last", il);
ggml_tensor * conv_state_update =
ggml_view_2d(ctx0, conv_states_all,
row_count, n_seqs, conv_states_all->nb[1],
(s_slot * mem_size + kv_head) * row_size);
cb(conv_state_update, "conv_state_update", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update));
} else {
const int64_t row_count = (conv_kernel_size - 1) * conv_channels;
const size_t row_size = row_count * ggml_element_size(conv_states_all);
for (int64_t t = 1; t <= n_seq_tokens; ++t) {
const uint32_t slot = (uint32_t)(n_seq_tokens - t);
ggml_tensor * src =
ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs,
conv_input->nb[1], conv_input->nb[2],
t * ggml_element_size(conv_input));
ggml_tensor * dst =
ggml_view_2d(ctx0, conv_states_all, row_count, n_seqs,
conv_states_all->nb[1],
((size_t) slot * mem_size + kv_head) * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
// [TAG_RECURRENT_ROLLBACK_SPLITS]
// TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are
// inside the same ubatch. currently with `split_equal()` this is not correct
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
for (int64_t t = 1; t <= K; ++t) {
const int64_t s_idx = std::max<int64_t>(0, conv_input->ne[0] - conv_states->ne[0] - K + t);
const int64_t s_slot = K - t;
ggml_tensor * conv_state_last =
ggml_view_3d(ctx0, conv_input,
conv_kernel_size - 1, conv_channels, n_seqs,
conv_input->nb[1], conv_input->nb[2],
ggml_row_size(conv_input->type, s_idx));
ggml_tensor * conv_state_update =
ggml_view_2d(ctx0,
conv_states_all, row_count, n_seqs,
conv_states_all->nb[1],
(s_slot * mem_size + kv_head) * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update));
}
}
@ -531,7 +544,9 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
const int64_t n_seqs = s->ne[3];
const int64_t n_seq_tokens = q->ne[2];
if (!keep_rs()) {
const bool keep = cparams.n_rs_seq > 0;
if (!keep) {
auto attn_out = build_delta_net(q, k, v, g, b, s, il);
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@ -547,14 +562,18 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
}
const int64_t D = S_v * S_v * H_v;
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
const int64_t K = cparams.n_rs_seq + 1;
// TODO: remove pad + simplify
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
if (n_seq_tokens > 1) {
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
} else {
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il);
}
const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs;
const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;
@ -576,9 +595,11 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
ggml_row_size(gdn_out->type, S_v * S_v),
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
((size_t) cache_slot * mem_size + kv_head) * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
}

View file

@ -66,9 +66,6 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * s,
int il);
// true when speculative rollback is enabled and the batch fits in the rs cache
bool keep_rs() const;
// read conv state from cache, concat with qkv_mixed, write back (single slot or per-token)
// qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs)
ggml_tensor * build_conv_state(

View file

@ -496,7 +496,8 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
// The MTP block lives at the source file's original layer index.
// hparams.n_layer includes both main model layers and MTP layers. The MTP
// layer is stored immediately after the main layers in model.layers[].
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
const auto & layer = model.layers[il];

View file

@ -30,7 +30,7 @@ int main(int argc, char ** argv) {
if (!params.fit_params_print) {
const common_params_fit_status status = common_fit_params(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
params.verbosity >= LOG_LEVEL_DEBUG ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
if (status != COMMON_PARAMS_FIT_STATUS_SUCCESS) {
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
exit(1);

View file

@ -467,20 +467,26 @@ struct server_slot {
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
SLT_INF(*this,
"\n"
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second);
SLT_INF(*this,
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
t_token_generation, n_decoded, t_gen, n_gen_second);
SLT_INF(*this,
" total time = %10.2f ms / %5d tokens\n",
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
t_token_generation, n_decoded, t_gen, n_gen_second,
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
SLT_INF(*this,
" graphs reused = %10d\n",
llama_perf_context(ctx_tgt).n_reused);
if (n_draft_total > 0) {
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
SLT_CNT(*this,
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total
);
SLT_INF(*this,
"draft acceptance = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total);
}
common_speculative_print_stats(spec);
@ -2583,9 +2589,9 @@ private:
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, pos_next - n_swa);
const auto pos_min_thold = std::max(0, pos_next - n_swa - 1);
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
if (n_past > 0 && n_past <= slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
@ -3885,6 +3891,7 @@ void server_routes::init_routes() {
{ "eos_token", meta->eos_token_str },
{ "build_info", meta->build_info },
{ "is_sleeping", queue_tasks.is_sleeping() },
{ "cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy },
};
if (params.use_jinja) {
if (!tmpl_tools.empty()) {

View file

@ -1165,6 +1165,7 @@ void server_models_routes::init_routes() {
// Deprecated: use ui_settings instead (kept for backward compat)
{"webui_settings", webui_settings},
{"build_info", std::string(llama_build_info())},
{"cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy},
});
return res;
}

2
tools/ui/.env.example Normal file
View file

@ -0,0 +1,2 @@
VITE_PUBLIC_APP_NAME='llama-ui'
# VITE_DEBUG='true'

2
tools/ui/.gitignore vendored
View file

@ -25,4 +25,4 @@ vite.config.ts.timestamp-*
*storybook.log
storybook-static
*.code-workspace
*.code-workspace

View file

@ -20,9 +20,7 @@ export default ts.config(
prettier,
...svelte.configs.prettier,
{
languageOptions: {
globals: { ...globals.browser, ...globals.node }
},
languageOptions: { globals: { ...globals.browser, ...globals.node } },
rules: {
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
@ -30,6 +28,7 @@ export default ts.config(
'svelte/no-at-html-tags': 'off',
// This app uses hash-based routing (#/) where resolve() from $app/paths does not apply
'svelte/no-navigation-without-resolve': 'off',
// Enforce empty line at end of file
'eol-last': 'error'
}

View file

@ -2307,9 +2307,9 @@
}
},
"node_modules/@sveltejs/kit": {
"version": "2.59.1",
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.59.1.tgz",
"integrity": "sha512-d8OON70AphLdDesuTIl//M2O6fRTIicX8aYv8vhCiYEhTTI2OboKqey0Hu1A4VFhqwgqtq0vKDmPFGkw8kKmgw==",
"version": "2.60.1",
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.60.1.tgz",
"integrity": "sha512-mQjlkNo+rJvpln7V2IGY2j99BqhcFbS4UN0AQNKNYfhBAFZTuCDAdW3a1sgf330mvtNvsBXn3HpAhcmvdJTcIQ==",
"dev": true,
"license": "MIT",
"dependencies": {
@ -2318,7 +2318,7 @@
"@types/cookie": "^0.6.0",
"acorn": "^8.14.1",
"cookie": "^0.6.0",
"devalue": "^5.6.4",
"devalue": "^5.8.1",
"esm-env": "^1.2.2",
"kleur": "^4.1.5",
"magic-string": "^0.30.5",
@ -4296,9 +4296,9 @@
}
},
"node_modules/devalue": {
"version": "5.6.4",
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.6.4.tgz",
"integrity": "sha512-Gp6rDldRsFh/7XuouDbxMH3Mx8GMCcgzIb1pDTvNyn8pZGQ22u+Wa+lGV9dQCltFQ7uVw0MhRyb8XDskNFOReA==",
"version": "5.8.1",
"resolved": "https://registry.npmjs.org/devalue/-/devalue-5.8.1.tgz",
"integrity": "sha512-4CXDYRBGqN+57wVJkuXBYmpAVUSg3L6JAQa/DFqm238G73E1wuyc/JhGQJzN7vUf/CMphYau2zXbfWzDR5aTEw==",
"license": "MIT"
},
"node_modules/devlop": {
@ -4856,12 +4856,12 @@
}
},
"node_modules/express-rate-limit": {
"version": "8.5.0",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.0.tgz",
"integrity": "sha512-XKhFohWaSBdVJNTi5TaHziqnPkv04I9UQV6q1Wy7Ui6GGQZVW12ojDFwqer14EvCXxjvPG0CyWXx7cAXpALB4Q==",
"version": "8.5.2",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.2.tgz",
"integrity": "sha512-5Kb34ipNX694DH48vN9irak1Qx30nb0PLYHXfJgw4YEjiC3ZEmZJhwOp+VfiCYwFzvFTdB9QkArYS5kXa2cx2A==",
"license": "MIT",
"dependencies": {
"ip-address": "10.1.0"
"ip-address": "^10.2.0"
},
"engines": {
"node": ">= 16"
@ -4909,9 +4909,9 @@
"license": "MIT"
},
"node_modules/fast-uri": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz",
"integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==",
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.2.tgz",
"integrity": "sha512-rVjf7ArG3LTk+FS6Yw81V1DLuZl1bRbNrev6Tmd/9RaroeeRRJhAt7jg/6YFxbvAQXUCavSoZhPPj6oOx+5KjQ==",
"funding": [
{
"type": "github",
@ -5541,9 +5541,9 @@
}
},
"node_modules/hono": {
"version": "4.12.14",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.14.tgz",
"integrity": "sha512-am5zfg3yu6sqn5yjKBNqhnTX7Cv+m00ox+7jbaKkrLMRJ4rAdldd1xPd/JzbBWspqaQv6RSTrgFN95EsfhC+7w==",
"version": "4.12.19",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.19.tgz",
"integrity": "sha512-xa3eYXYXx68XTT4hZ7dRzsXBhaq85ToSrlUJNoR0gwz/1Ap/CNwX47wfvV7pc/xWhjKVVkLT7zBJy8chhNguqQ==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"
@ -5722,9 +5722,9 @@
"license": "MIT"
},
"node_modules/ip-address": {
"version": "10.1.0",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.1.0.tgz",
"integrity": "sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==",
"version": "10.2.0",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz",
"integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==",
"license": "MIT",
"engines": {
"node": ">= 12"
@ -6008,9 +6008,9 @@
}
},
"node_modules/katex": {
"version": "0.16.22",
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.22.tgz",
"integrity": "sha512-XCHRdUw4lf3SKBaJe4EvgqIuWwkPSo9XoeO8GjQW94Bp7TWv9hNhzZjZ+OH9yf1UmLygb7DIT5GSFQiyt16zYg==",
"version": "0.16.47",
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.47.tgz",
"integrity": "sha512-Eeo8Ys1doU1z+x8AZsPpQu+p/QcZBI5PeOo7QGQdy2x2m0MU/hYagBbGOmXwr5KVbEfVuWv9LpnQWeehogurjg==",
"dev": true,
"funding": [
"https://opencollective.com/katex",
@ -9245,9 +9245,9 @@
}
},
"node_modules/svelte": {
"version": "5.55.1",
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.1.tgz",
"integrity": "sha512-QjvU7EFemf6mRzdMGlAFttMWtAAVXrax61SZYHdkD6yoVGQ89VeyKfZD4H1JrV1WLmJBxWhFch9H6ig/87VGjw==",
"version": "5.55.7",
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.55.7.tgz",
"integrity": "sha512-ymI5ykLPwIHW839E053FQbI1G+jnRFJEw3Kv5Y4njixVWywQBx+NUFpkkKyk5LIb36Fg9DVXSYpqiGekLD0hyw==",
"license": "MIT",
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
@ -9259,7 +9259,7 @@
"aria-query": "5.3.1",
"axobject-query": "^4.1.0",
"clsx": "^2.1.1",
"devalue": "^5.6.4",
"devalue": "^5.8.1",
"esm-env": "^1.2.1",
"esrap": "^2.2.4",
"is-reference": "^3.0.3",
@ -10606,9 +10606,9 @@
"license": "ISC"
},
"node_modules/ws": {
"version": "8.18.3",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz",
"integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==",
"version": "8.20.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.20.1.tgz",
"integrity": "sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w==",
"dev": true,
"license": "MIT",
"engines": {

View file

@ -1,6 +1,7 @@
@import 'tailwindcss';
@source ".";
@source '.';
@plugin '@tailwindcss/forms';
@plugin '@tailwindcss/typography';
@import 'tw-animate-css';
@custom-variant dark (&:is(.dark *));

View file

@ -15,6 +15,7 @@
{#if videoSrc}
<video controls class="mb-4 w-full" src={videoSrc}>
<track kind="captions" src="" />
Your browser does not support the video element.
</video>
{:else}

View file

@ -7,7 +7,6 @@
import { activeMessages } from '$lib/stores/conversations.svelte';
interface Props {
currentModel?: string;
disabled?: boolean;
forceForegroundText?: boolean;
hasAudioModality?: boolean;
@ -20,7 +19,6 @@
}
let {
currentModel,
disabled = false,
forceForegroundText = false,
hasAudioModality = $bindable(false),
@ -41,14 +39,28 @@
let lastSyncedConversationModel: string | null = null;
let selectorModel = $derived(conversationModel ?? modelsStore.selectedModelName ?? null);
$effect(() => {
if (conversationModel && conversationModel !== lastSyncedConversationModel) {
lastSyncedConversationModel = conversationModel;
if (modelOptions().some((m) => m.model === conversationModel)) {
modelsStore.selectedModelName = conversationModel;
modelsStore.selectModelByName(conversationModel);
} else {
modelsStore.selectedModelName = null;
modelsStore.clearSelection();
}
modelsStore.selectModelByName(conversationModel);
} else if (isRouter && !modelsStore.selectedModelId && modelsStore.loadedModelIds.length > 0) {
lastSyncedConversationModel = conversationModel;
} else if (
isRouter &&
!modelsStore.selectedModelId &&
modelsStore.loadedModelIds.length > 0 &&
activeMessages().length > 0 &&
!conversationModel
) {
lastSyncedConversationModel = null;
// auto-select the first loaded model only when nothing is selected yet
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) modelsStore.selectModelById(first.id);
@ -151,7 +163,7 @@
<ModelsSelectorSheet
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
currentModel={selectorModel}
{forceForegroundText}
{useGlobalSelection}
/>
@ -159,7 +171,7 @@
<ModelsSelectorDropdown
disabled={disabled || isOffline}
bind:this={selectorModelRef}
{currentModel}
currentModel={selectorModel}
{forceForegroundText}
{useGlobalSelection}
/>

View file

@ -162,7 +162,7 @@
return;
}
if (import.meta.env.DEV) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log('[ChatFormPickerMcpPrompts] Fetching completions for:', {
serverName: selectedPrompt.serverName,
promptName: selectedPrompt.name,
@ -181,7 +181,7 @@
value
);
if (import.meta.env.DEV) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log('[ChatFormPickerMcpPrompts] Autocomplete result:', {
argName,
value,

View file

@ -1,20 +1,20 @@
<script lang="ts">
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
import { Trash2 } from '@lucide/svelte';
import { afterNavigate } from '$app/navigation';
import { page } from '$app/state';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import {
ChatScreenForm,
ChatMessages,
ChatScreenDragOverlay,
ChatScreenProcessingInfo,
ChatScreenActionScrollDown,
DialogEmptyFileAlert,
DialogFileUploadError,
DialogChatError,
ServerLoadingSplash,
DialogConfirmation
DialogConfirmation,
ChatScreenServerError
} from '$lib/components/app';
import * as Alert from '$lib/components/ui/alert';
import { setProcessingInfoContext } from '$lib/contexts';
import { ErrorDialogType } from '$lib/enums';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
@ -34,11 +34,12 @@
activeConversation
} from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { serverLoading, serverError, serverStore, isRouterMode } from '$lib/stores/server.svelte';
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
import { onMount } from 'svelte';
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
let { showCenteredEmpty = false } = $props();
@ -67,6 +68,8 @@
let showEmptyFileDialog = $state(false);
let processingInfoVisible = $state(false);
let emptyFileNames = $state<string[]>([]);
let initialMessage = $state('');
@ -174,6 +177,10 @@
showDeleteDialog = false;
}
function handleProcessingInfoVisibility(visible: boolean) {
processingInfoVisible = visible;
}
function handleDragEnter(event: DragEvent) {
event.preventDefault();
@ -338,7 +345,9 @@
});
function handleMessagesReady() {
if (!disableAutoScroll && !autoScroll.userScrolledUp) {
if (disableAutoScroll) return;
if (!autoScroll.userScrolledUp) {
requestAnimationFrame(() => {
autoScroll.scrollToBottom('instant');
});
@ -392,59 +401,32 @@
{#if !isEmpty}
<ChatMessages
messages={activeMessages()}
onMessagesReady={handleMessagesReady}
onUserAction={() => {
autoScroll.enable();
if (!autoScroll.userScrolledUp) {
autoScroll.scrollToBottom();
}
}}
onMessagesReady={handleMessagesReady}
/>
{/if}
<div
class="pointer-events-none {isEmpty
? 'absolute bottom-[calc(50dvh-7rem)]'
: 'sticky bottom-4'} right-4 left-4 mt-auto pt-16 transition-all duration-200"
class={[
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
]}
>
{#if isEmpty}
<div class="mb-8 px-4 text-center" use:fadeInView={{ duration: 300 }}>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
<ChatScreenGreeting {isEmpty} />
<p class="text-muted-foreground md:text-lg">
{serverStore.props?.modalities?.audio
? 'Record audio, type a message '
: 'Type a message'} or upload files to get started
</p>
</div>
{/if}
<ChatScreenActionScrollDown
container={chatScrollContainer}
hasProcessingInfoVisible={processingInfoVisible}
/>
{#if page.params.id}
<ChatScreenProcessingInfo />
{/if}
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
{#if hasPropsError}
<div
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
use:fadeInView={{ y: 10, duration: 250 }}
>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
<Alert.Title class="flex items-center justify-between">
<span>Server unavailable</span>
<button
onclick={() => serverStore.fetch()}
disabled={isServerLoading}
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
>
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
{isServerLoading ? 'Retrying...' : 'Retry'}
</button>
</Alert.Title>
<Alert.Description>{serverError()}</Alert.Description>
</Alert.Root>
</div>
{/if}
<ChatScreenServerError />
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
<ChatScreenForm

View file

@ -0,0 +1,61 @@
<script lang="ts">
import { ArrowDown } from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
interface Props {
container: HTMLDivElement | undefined;
hasProcessingInfoVisible: boolean;
}
let { container, hasProcessingInfoVisible }: Props = $props();
let show = $state(false);
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
function checkVisibility() {
if (!container) return;
const { scrollTop, scrollHeight, clientHeight } = container;
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
show = distanceFromBottom > clientHeight * 0.5;
}
function scrollToBottom() {
if (container) {
container.scrollTo({
top: container.scrollHeight,
behavior: 'smooth'
});
}
}
$effect(() => {
const c = container;
if (c) {
c.addEventListener('scroll', checkVisibility);
checkVisibility();
return () => {
c.removeEventListener('scroll', checkVisibility);
};
}
});
</script>
<div
class="pointer-events-{show
? 'auto'
: 'none'} relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center"
>
<Button
onclick={scrollToBottom}
variant="secondary"
size="icon"
class="pointer-events-all absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
? 1
: 0};"
aria-label="Scroll to bottom"
>
<ArrowDown class="h-4 w-4" />
</Button>
</div>

View file

@ -0,0 +1,25 @@
<script lang="ts">
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import { serverStore } from '$lib/stores/server.svelte';
interface Props {
isEmpty: boolean;
}
let { isEmpty = false }: Props = $props();
</script>
<div
class={[
'pointer-events-none mb-4 hidden px-4 text-center',
isEmpty && 'pointer-events-auto block!'
]}
use:fadeInView={{ duration: 300 }}
>
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
<p class="text-muted-foreground md:text-lg">
{serverStore.props?.modalities?.audio ? 'Record audio, type a message ' : 'Type a message'} or upload
files to get started
</p>
</div>

View file

@ -6,6 +6,7 @@
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
import { config } from '$lib/stores/settings.svelte';
import { getProcessingInfoContext } from '$lib/contexts';
import { page } from '$app/state';
const processingState = useProcessingState();
const processingInfoCtx = getProcessingInfoContext();
@ -16,6 +17,14 @@
let isStreaming = $derived(isChatStreaming());
let processingDetails = $derived(processingState.getTechnicalDetails());
let processingVisible = $derived(processingDetails.length > 0);
let { onVisibilityChange }: { onVisibilityChange?: (visible: boolean) => void } = $props();
$effect(() => {
onVisibilityChange?.(processingVisible);
});
$effect(() => {
const conversation = activeConversation();
@ -60,9 +69,12 @@
</script>
<div
class={['chat-processing-info-container pointer-events-none', showProcessingInfo && 'visible']}
class={[
'chat-processing-info-container pointer-events-none relative',
page.params.id && showProcessingInfo && 'visible'
]}
>
<div class="chat-processing-info-content">
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
{#each processingDetails as detail (detail)}
<span class="chat-processing-info-detail pointer-events-auto backdrop-blur-sm">{detail}</span>
{/each}

View file

@ -0,0 +1,34 @@
<script lang="ts">
import { AlertTriangle, RefreshCw } from '@lucide/svelte';
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
import * as Alert from '$lib/components/ui/alert';
import { serverError, serverLoading, serverStore } from '$lib/stores/server.svelte';
let hasError = $derived(!!serverError());
</script>
{#if hasError}
<div
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
use:fadeInView={{ y: 10, duration: 250 }}
>
<Alert.Root variant="destructive">
<AlertTriangle class="h-4 w-4" />
<Alert.Title class="flex items-center justify-between">
<span>Server unavailable</span>
<button
onclick={() => serverStore.fetch()}
disabled={serverLoading()}
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
>
<RefreshCw class="h-3 w-3 {serverLoading() ? 'animate-spin' : ''}" />
{serverLoading() ? 'Retrying...' : 'Retry'}
</button>
</Alert.Title>
<Alert.Description>{serverError()}</Alert.Description>
</Alert.Root>
</div>
{/if}

View file

@ -667,3 +667,17 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
* Only visible when `isCurrentConversationLoading` is true.
*/
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
/**
* Scroll-to-bottom action button. Displays a floating button when the user
* has scrolled up more than half a viewport height from the bottom.
* Takes the chat container element as a prop to manage scroll state internally.
*/
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
/**
* Server error alert displayed when the server is unreachable.
* Shows the error message with a retry button.
* Rendered inside ChatScreen when `serverError` store has a value.
*/
export { default as ChatScreenServerError } from './ChatScreen/ChatScreenServerError.svelte';

View file

@ -28,7 +28,7 @@
SETTINGS_KEYS
} from '$lib/constants';
import { ColorMode, UrlProtocol } from '$lib/enums';
import { FileTypeText } from '$lib/enums/files';
import { FileTypeText } from '$lib/enums/files.enums';
import { highlightCode, detectIncompleteCodeBlock, type IncompleteCodeBlock } from '$lib/utils';
import '$styles/katex-custom.scss';
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';

View file

@ -17,7 +17,7 @@
} from '$lib/constants';
import { RouterService } from '$lib/services/router.service';
import { setMode } from 'mode-watcher';
import { ColorMode } from '$lib/enums/ui';
import { ColorMode } from '$lib/enums/ui.enums';
import { fade } from 'svelte/transition';
import { goto } from '$app/navigation';
import { page } from '$app/state';

View file

@ -6,7 +6,7 @@
import * as Select from '$lib/components/ui/select';
import { Textarea } from '$lib/components/ui/textarea';
import { SETTING_CONFIG_INFO, SETTINGS_KEYS } from '$lib/constants';
import { SettingsFieldType } from '$lib/enums/settings';
import { SettingsFieldType } from '$lib/enums/settings.enums';
import { settingsStore } from '$lib/stores/settings.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { modelsStore, selectedModelName, propsCacheVersion } from '$lib/stores/models.svelte';

View file

@ -2,7 +2,7 @@ import { Zap, Globe, Radio } from '@lucide/svelte';
import { MCPTransportType } from '$lib/enums';
import type { ClientCapabilities, Implementation } from '$lib/types';
import type { Component } from 'svelte';
import { MimeTypeImage } from '$lib/enums/files';
import { MimeTypeImage } from '$lib/enums/files.enums';
export const DEFAULT_CLIENT_VERSION = '1.0.0';
export const MCP_CLIENT_NAME = 'llama-ui-mcp';

View file

@ -1,5 +1,5 @@
import { ColorMode } from '$lib/enums/ui';
import { SettingsFieldType } from '$lib/enums/settings';
import { ColorMode } from '$lib/enums/ui.enums';
import { SettingsFieldType } from '$lib/enums/settings.enums';
import { SyncableParameterType } from '$lib/enums';
import {
Funnel,

View file

@ -18,7 +18,7 @@ import {
MimeTypeApplication,
MimeTypeText
} from '$lib/enums';
import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files';
import { FileExtensionVideo, FileTypeVideo } from '$lib/enums/files.enums';
// File type configuration using enums
export const AUDIO_FILE_TYPES = {

View file

@ -1,4 +1,4 @@
import { ToolSource } from '$lib/enums/tools';
import { ToolSource } from '$lib/enums/tools.enums';
export const TOOL_GROUP_LABELS = {
[ToolSource.BUILTIN]: 'Built-in',

View file

@ -4,9 +4,9 @@ export {
AttachmentItemEnabledWhen,
AttachmentAction,
AttachmentItemVisibleWhen
} from './attachment';
} from './attachment.enums';
export { AgenticSectionType, ToolCallType } from './agentic';
export { AgenticSectionType, ToolCallType } from './agentic.enums';
export {
ChatMessageStatsView,
@ -17,7 +17,7 @@ export {
MessageType,
PdfViewMode,
ReasoningFormat
} from './chat';
} from './chat.enums';
export {
FileTypeCategory,
@ -38,7 +38,7 @@ export {
MimeTypeImage,
MimeTypeText,
SpecialFileType
} from './files';
} from './files.enums';
export {
MCPConnectionPhase,
@ -48,16 +48,16 @@ export {
MCPContentType,
MCPRefType,
JsonSchemaType
} from './mcp';
} from './mcp.enums';
export { ModelModality } from './model';
export { ModelModality } from './model.enums';
export { ServerRole, ServerModelStatus } from './server';
export { ServerRole, ServerModelStatus } from './server.enums';
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings';
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui';
export { ColorMode, HtmlInputType, McpPromptVariant, TooltipSide, UrlProtocol } from './ui.enums';
export { KeyboardKey } from './keyboard';
export { KeyboardKey } from './keyboard.enums';
export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools';
export { ToolSource, ToolPermissionDecision, ToolResponseField } from './tools.enums';

View file

@ -100,6 +100,14 @@ export class AutoScrollController {
this._autoScrollEnabled = true;
}
/**
* Resets scroll state when switching conversations.
*/
resetScrollState(): void {
this._userScrolledUp = false;
this._autoScrollEnabled = true;
}
/**
* Starts the auto-scroll interval for continuous scrolling during streaming.
*/

View file

@ -66,7 +66,6 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
const serverModel = $derived(singleModelName());
const currentModel = $derived(opts.currentModel());
const useGlobalSelection = $derived(opts.useGlobalSelection?.() ?? false);
const onModelChange = $derived(opts.onModelChange?.());
const isHighlightedCurrentModelActive = $derived.by(() => {
@ -128,6 +127,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
if (onModelChange) {
const result = await onModelChange(option.id, option.model);
if (result === false) {
shouldCloseMenu = false;
}
@ -142,12 +142,14 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
const textarea = document.querySelector<HTMLTextAreaElement>(
'[data-slot="chat-form"] textarea'
);
textarea?.focus();
});
}
if (!onModelChange && isRouter && !modelsStore.isModelLoaded(option.model)) {
isLoadingModel = true;
modelsStore
.loadModel(option.model)
.catch((error) => console.error('Failed to load model:', error))
@ -158,6 +160,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
function getDisplayOption(): ModelOption | undefined {
if (!isRouter) {
const displayModel = serverModel || currentModel;
if (displayModel) {
return {
id: serverModel ? 'current' : 'offline-current',
@ -166,12 +169,8 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
capabilities: []
};
}
return undefined;
}
if (useGlobalSelection && activeId) {
const selected = options.find((option) => option.id === activeId);
if (selected) return selected;
return undefined;
}
if (currentModel) {
@ -183,6 +182,7 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
capabilities: []
};
}
return options.find((option) => option.model === currentModel);
}
@ -197,57 +197,77 @@ export function useModelsSelector(opts: UseModelsSelectorOptions): UseModelsSele
get options() {
return options;
},
get loading() {
return loading;
},
get updating() {
return updating;
},
get activeId() {
return activeId;
},
get isRouter() {
return isRouter;
},
get serverModel() {
return serverModel;
},
get isHighlightedCurrentModelActive() {
return isHighlightedCurrentModelActive;
},
get isCurrentModelInCache() {
return isCurrentModelInCache;
},
get filteredOptions() {
return filteredOptions;
},
get groupedFilteredOptions() {
return groupedFilteredOptions;
},
get isLoadingModel() {
return isLoadingModel;
},
get searchTerm() {
return searchTerm;
},
get showModelDialog() {
return showModelDialog;
},
get infoModelId() {
return infoModelId;
},
setSearchTerm(value: string) {
searchTerm = value;
},
setShowModelDialog(value: boolean) {
showModelDialog = value;
},
handleInfoClick,
handleSelect,
handleOpenChange,
isFavorite(model: string) {
return modelsStore.favoriteModelIds.has(model);
},
getDisplayOption
};
}

View file

@ -392,7 +392,7 @@ export class MCPService {
const url = new URL(config.url);
if (import.meta.env.DEV) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService] Creating WebSocket transport for ${url.href}`);
}
@ -413,12 +413,12 @@ export class MCPService {
onLog
);
if (useProxy && import.meta.env.DEV) {
if (useProxy && import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService] Using CORS proxy for ${config.url} -> ${url.href}`);
}
try {
if (import.meta.env.DEV) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService] Creating StreamableHTTP transport for ${url.href}`);
}
@ -520,7 +520,7 @@ export class MCPService {
)
);
if (import.meta.env.DEV) {
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService][${serverName}] Creating transport...`);
}
@ -560,6 +560,22 @@ export class MCPService {
);
const runtimeErrorHandler = (error: Error) => {
// Ignore errors that are expected when the SDK's transport is closed,
// or when connecting to servers that don't support SSE (stateless-only
// endpoints returning 405). The SDK wraps the original AbortError in
// a new Error with the message "SSE stream disconnected: AbortError",
// and also produces "Cannot cancel a stream locked by a reader".
// DOMException is thrown by the browser when aborting fetch requests.
const msg = error.message || String(error);
if (
error.name === 'AbortError' ||
error instanceof DOMException ||
msg.includes('SSE stream disconnected') ||
msg.includes('stream locked by a reader') ||
msg.includes('The operation was aborted')
) {
return;
}
console.error(`[MCPService][${serverName}] Protocol error after initialize:`, error);
};
@ -658,7 +674,10 @@ export class MCPService {
this.createLog(MCPConnectionPhase.LISTING_TOOLS, 'Listing available tools...')
);
console.log(`[MCPService][${serverName}] Connected, listing tools...`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService][${serverName}] Connected, listing tools...`);
}
const tools = await this.listTools({
client,
transport,
@ -680,10 +699,11 @@ export class MCPService {
`Connection established with ${tools.length} tools (${connectionTimeMs}ms)`
)
);
console.log(
`[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms`
);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(
`[MCPService][${serverName}] Initialization complete with ${tools.length} tools in ${connectionTimeMs}ms`
);
}
return {
client,
@ -709,9 +729,22 @@ export class MCPService {
* @param connection - The active MCP connection to close
*/
static async disconnect(connection: MCPConnection): Promise<void> {
console.log(`[MCPService][${connection.serverName}] Disconnecting...`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService][${connection.serverName}] Disconnecting...`);
}
try {
// Prevent reconnection on voluntary disconnect
// Terminate the session first for streamable-http transports to cleanly
// close streams, matching the inspector's disconnect flow.
if (connection.transport instanceof StreamableHTTPClientTransport) {
await connection.transport.terminateSession();
}
// Clear error handlers before closing to prevent noise from expected
// abort errors during shutdown. The inspector avoids this entirely
// by not setting onerror, but since we use it for protocol logging,
// we must clear it before disconnect.
connection.client.onerror = undefined;
if (connection.transport.onclose) {
connection.transport.onclose = undefined;
}
@ -1078,7 +1111,9 @@ export class MCPService {
try {
await connection.client.unsubscribeResource({ uri });
console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(`[MCPService][${connection.serverName}] Unsubscribed from resource: ${uri}`);
}
} catch (error) {
console.error(
`[MCPService][${connection.serverName}] Failed to unsubscribe from resource:`,

View file

@ -119,7 +119,8 @@ const localStorageMigration: Migration = {
// Only migrate if new key doesn't already exist
const newValue = localStorage.getItem(newKey);
if (newValue !== null) {
console.log(`[Migration] localStorage: ${newKey} already exists, skipping`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] localStorage: ${newKey} already exists, skipping`);
continue;
}
@ -127,9 +128,11 @@ const localStorageMigration: Migration = {
if (oldValue !== null) {
localStorage.setItem(newKey, oldValue);
// Keep old key for downgrade compatibility - DO NOT DELETE
console.log(
`[Migration] localStorage: copied ${deprecatedKey}${newKey} (preserved old)`
);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
console.log(
`[Migration] localStorage: copied ${deprecatedKey}${newKey} (preserved old)`
);
}
}
}
}
@ -146,7 +149,8 @@ const idxdbMigration: Migration = {
async run(): Promise<void> {
const oldDbNames = await Dexie.getDatabaseNames();
if (!oldDbNames.includes(DB_APP_NAME_DEPRECATED)) {
console.log('[Migration] IndexedDB: no old database found, skipping');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] IndexedDB: no old database found, skipping');
return;
}
@ -155,11 +159,13 @@ const idxdbMigration: Migration = {
newDb.version(1).stores(IDXDB_STORES);
const existingConvs = await newDb.table(IDXDB_TABLES.conversations).count();
if (existingConvs > 0) {
console.log('[Migration] IndexedDB: new database already has data, skipping');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] IndexedDB: new database already has data, skipping');
return;
}
console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] IndexedDB: copying from', DB_APP_NAME_DEPRECATED);
const oldDb = new Dexie(DB_APP_NAME_DEPRECATED);
oldDb.version(1).stores(IDXDB_STORES);
@ -169,15 +175,18 @@ const idxdbMigration: Migration = {
if (conversations.length > 0) {
await newDb.table(IDXDB_TABLES.conversations).bulkAdd(conversations);
console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] IndexedDB: copied ${conversations.length} conversations`);
}
if (messages.length > 0) {
await newDb.table(IDXDB_TABLES.messages).bulkAdd(messages);
console.log(`[Migration] IndexedDB: copied ${messages.length} messages`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] IndexedDB: copied ${messages.length} messages`);
}
// Non-destructive: DO NOT delete old database - keep for downgrade compatibility
console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] IndexedDB: preserved old database for downgrade compatibility');
}
};
@ -419,7 +428,8 @@ const legacyMessageMigration: Migration = {
}
}
console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] Legacy messages: migrated ${migratedCount} messages`);
}
};
@ -434,7 +444,8 @@ const themeMigration: Migration = {
async run(): Promise<void> {
const legacyTheme = localStorage.getItem('theme');
if (legacyTheme === null) {
console.log('[Migration] Theme: no legacy theme key found, skipping');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] Theme: no legacy theme key found, skipping');
return;
}
@ -443,7 +454,8 @@ const themeMigration: Migration = {
const config = configRaw ? JSON.parse(configRaw) : {};
if (SETTINGS_KEYS.THEME in config) {
console.log('[Migration] Theme: config already has theme, skipping');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] Theme: config already has theme, skipping');
return;
}
@ -451,7 +463,8 @@ const themeMigration: Migration = {
localStorage.setItem(CONFIG_LOCALSTORAGE_KEY, JSON.stringify(config));
// Non-destructive: DO NOT delete legacy theme key - keep for downgrade compatibility
console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] Theme: copied standalone theme to config (preserved old key)`);
}
};
@ -491,7 +504,8 @@ export const MigrationService = {
*/
resetState(): void {
localStorage.removeItem(MIGRATION_STATE_KEY);
console.log('[Migration] State reset - all migrations will run again');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] State reset - all migrations will run again');
},
/**
@ -500,25 +514,30 @@ export const MigrationService = {
*/
async runAllMigrations(): Promise<void> {
const state = getMigrationState();
console.log('[Migration] Starting migration run, state:', state);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] Starting migration run, state:', state);
for (const migration of migrations) {
if (isMigrationCompleted(migration.id)) {
console.log(`[Migration] ${migration.id}: already completed, skipping`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] ${migration.id}: already completed, skipping`);
continue;
}
try {
console.log(`[Migration] ${migration.id}: running...`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] ${migration.id}: running...`);
await migration.run();
markMigrationCompleted(migration.id);
console.log(`[Migration] ${migration.id}: completed successfully`);
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log(`[Migration] ${migration.id}: completed successfully`);
} catch (error) {
console.error(`[Migration] ${migration.id}: failed`, error);
markMigrationFailed(migration.id);
}
}
console.log('[Migration] All migrations complete');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] All migrations complete');
}
};

View file

@ -20,11 +20,11 @@
*/
import { browser } from '$app/environment';
import { base } from '$app/paths';
import { SETTINGS_KEYS } from '$lib/constants';
import { MCPService } from '$lib/services/mcp.service';
import { config, settingsStore } from '$lib/stores/settings.svelte';
import { mcpResourceStore } from '$lib/stores/mcp-resources.svelte';
import { serverStore } from '$lib/stores/server.svelte';
import { mode } from 'mode-watcher';
import {
parseMcpServerSettings,
@ -43,7 +43,6 @@ import {
ToolCallType
} from '$lib/enums';
import {
CORS_PROXY_ENDPOINT,
DEFAULT_CACHE_TTL_MS,
DEFAULT_MCP_CONFIG,
EXPECTED_THEMED_ICON_PAIR_COUNT,
@ -86,7 +85,6 @@ class MCPStore {
private _toolCount = $state(0);
private _connectedServers = $state<string[]>([]);
private _healthChecks = $state<Record<string, HealthCheckState>>({});
private _proxyAvailable = $state(false);
private connections = new Map<string, MCPConnection>();
private toolsIndex = new Map<string, string>();
@ -96,27 +94,8 @@ class MCPStore {
private initPromise: Promise<boolean> | null = null;
private activeFlowCount = 0;
constructor() {
if (browser) {
this.probeProxy();
}
}
/**
* Probes the CORS proxy endpoint to determine availability.
* The endpoint is only registered when llama-server runs with --ui-mcp-proxy.
*/
async probeProxy(): Promise<void> {
try {
const response = await fetch(`${base}${CORS_PROXY_ENDPOINT}`, { method: 'HEAD' });
this._proxyAvailable = response.status !== 404;
} catch {
this._proxyAvailable = false;
}
}
get isProxyAvailable(): boolean {
return this._proxyAvailable;
return serverStore.props?.cors_proxy_enabled ?? false;
}
/**

View file

@ -3,7 +3,7 @@ import { toast } from 'svelte-sonner';
import { ServerModelStatus, ModelModality } from '$lib/enums';
import { ModelsService } from '$lib/services/models.service';
import { PropsService } from '$lib/services/props.service';
import { serverStore } from '$lib/stores/server.svelte';
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
import { TTLCache } from '$lib/utils';
import {
MODEL_PROPS_CACHE_TTL_MS,
@ -14,14 +14,7 @@ import {
import { conversationsStore } from '$lib/stores/conversations.svelte';
/**
* modelsStore - Reactive store for model management in both MODEL and ROUTER modes
*
* This store manages:
* - Available models list
* - Selected model for new conversations
* - Loaded models tracking (ROUTER mode)
* - Model usage tracking per conversation
* - Automatic unloading of unused models
* modelsStore - Reactive store for model management in both MODEL and ROUTER modes.
*
* **Architecture & Relationships:**
* - **ModelsService**: Stateless service for model API communication
@ -31,14 +24,8 @@ import { conversationsStore } from '$lib/stores/conversations.svelte';
*
* **API Inconsistency Workaround:**
* In MODEL mode, `/props` returns modalities for the single model.
* In ROUTER mode, `/props` has no modalities - must use `/props?model=<id>` per model.
* In ROUTER mode, `/props` has no modalities must use `/props?model=<id>` per model.
* This store normalizes this behavior so consumers don't need to know the server mode.
*
* **Key Features:**
* - **MODEL mode**: Single model, always loaded
* - **ROUTER mode**: Multi-model with load/unload capability
* - **Auto-unload**: Automatically unloads models not used by any conversation
* - **Lazy loading**: ensureModelLoaded() loads models on demand
*/
class ModelsStore {
/**
@ -57,8 +44,8 @@ class ModelsStore {
selectedModelId = $state<string | null>(null);
selectedModelName = $state<string | null>(null);
// dedup concurrent fetch() callers, all awaiters share the same inflight promise
// without this, ?model=<name> URL handler raced an in-progress fetch and saw an empty list
// Dedup concurrent fetch() callers — all awaiters share the same inflight promise.
// Without this, ?model=<name> URL handler races an in-progress fetch and sees an empty list.
private inflightFetch: Promise<void> | null = null;
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
@ -67,9 +54,9 @@ class ModelsStore {
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
/**
* Model-specific props cache with TTL
* Key: modelId, Value: props data including modalities
* TTL: 10 minutes - props don't change frequently
* Model-specific props cache with TTL.
* Key: modelId, Value: props data including modalities.
* TTL: 10 minutes props don't change frequently.
*/
private modelPropsCache = new TTLCache<string, ApiLlamaCppServerProps>({
ttlMs: MODEL_PROPS_CACHE_TTL_MS,
@ -78,7 +65,7 @@ class ModelsStore {
private modelPropsFetching = $state<Set<string>>(new Set());
/**
* Version counter for props cache - used to trigger reactivity when props are updated
* Version counter for props cache used to trigger reactivity when props are updated.
*/
propsCacheVersion = $state(0);
@ -92,7 +79,7 @@ class ModelsStore {
get selectedModel(): ModelOption | null {
if (!this.selectedModelId) return null;
return this.models.find((model) => model.id === this.selectedModelId) ?? null;
return this.models.find((m) => m.id === this.selectedModelId) ?? null;
}
get loadedModelIds(): string[] {
@ -117,7 +104,7 @@ class ModelsStore {
* In ROUTER mode, returns null (model is per-conversation).
*/
get singleModelName(): string | null {
if (serverStore.isRouterMode) return null;
if (isRouterMode()) return null;
const props = serverStore.props;
if (props?.model_alias) return props.model_alias;
@ -126,6 +113,11 @@ class ModelsStore {
return props.model_path.split(/(\\|\/)/).pop() || null;
}
get selectedModelContextSize(): number | null {
if (!this.selectedModelName) return null;
return this.getModelContextSize(this.selectedModelName);
}
/**
*
*
@ -134,10 +126,6 @@ class ModelsStore {
*
*/
/**
* Get modalities for a specific model
* Returns cached modalities from model props
*/
getModelModalities(modelId: string): ModelModalities | null {
const model = this.models.find((m) => m.model === modelId || m.id === modelId);
if (model?.modalities) {
@ -146,46 +134,29 @@ class ModelsStore {
const props = this.modelPropsCache.get(modelId);
if (props?.modalities) {
return {
vision: props.modalities.vision ?? false,
audio: props.modalities.audio ?? false,
video: props.modalities.video ?? false
};
return this.buildModalities(props.modalities);
}
return null;
}
/**
* Check if a model supports vision modality
*/
modelSupportsVision(modelId: string): boolean {
return this.getModelModalities(modelId)?.vision ?? false;
}
/**
* Check if a model supports audio modality
*/
modelSupportsAudio(modelId: string): boolean {
return this.getModelModalities(modelId)?.audio ?? false;
}
/**
* Check if a model supports video modality
*/
modelSupportsVideo(modelId: string): boolean {
return this.getModelModalities(modelId)?.video ?? false;
}
/**
* Get model modalities as an array of ModelModality enum values
*/
getModelModalitiesArray(modelId: string): ModelModality[] {
const modalities = this.getModelModalities(modelId);
if (!modalities) return [];
const result: ModelModality[] = [];
if (modalities.vision) result.push(ModelModality.VISION);
if (modalities.audio) result.push(ModelModality.AUDIO);
if (modalities.video) result.push(ModelModality.VIDEO);
@ -193,16 +164,10 @@ class ModelsStore {
return result;
}
/**
* Get props for a specific model (from cache)
*/
getModelProps(modelId: string): ApiLlamaCppServerProps | null {
return this.modelPropsCache.get(modelId);
}
/**
* Get context size (n_ctx) for a specific model from cached props
*/
getModelContextSize(modelId: string): number | null {
const props = this.getModelProps(modelId);
const nCtx = props?.default_generation_settings?.n_ctx;
@ -210,17 +175,6 @@ class ModelsStore {
return typeof nCtx === 'number' ? nCtx : null;
}
/**
* Get context size for the currently selected model or null if no model is selected
*/
get selectedModelContextSize(): number | null {
if (!this.selectedModelName) return null;
return this.getModelContextSize(this.selectedModelName);
}
/**
* Check if props are being fetched for a model
*/
isModelPropsFetching(modelId: string): boolean {
return this.modelPropsFetching.has(modelId);
}
@ -235,10 +189,10 @@ class ModelsStore {
isModelLoaded(modelId: string): boolean {
const model = this.routerModels.find((m) => m.id === modelId);
return (
model?.status.value === ServerModelStatus.LOADED ||
model?.status.value === ServerModelStatus.SLEEPING ||
false
model?.status.value === ServerModelStatus.SLEEPING
);
}
@ -248,6 +202,7 @@ class ModelsStore {
getModelStatus(modelId: string): ServerModelStatus | null {
const model = this.routerModels.find((m) => m.id === modelId);
return model?.status.value ?? null;
}
@ -257,6 +212,7 @@ class ModelsStore {
isModelInUse(modelId: string): boolean {
const usage = this.modelUsage.get(modelId);
return usage !== undefined && usage.size > 0;
}
@ -269,8 +225,8 @@ class ModelsStore {
*/
/**
* Fetch list of models from server and detect server role
* Also fetches modalities for MODEL mode (single model)
* Fetch list of models from server and detect server role.
* Also fetches modalities for MODEL mode (single model).
*/
async fetch(force = false): Promise<void> {
if (this.inflightFetch) return this.inflightFetch;
@ -293,69 +249,87 @@ class ModelsStore {
await serverStore.fetch();
}
const response = await ModelsService.list();
const router = isRouterMode();
const models: ModelOption[] = response.data.map((item: ApiModelDataEntry, index: number) => {
const details = response.models?.[index];
const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : [];
const displayNameSource =
details?.name && details.name.trim().length > 0 ? details.name : item.id;
const displayName = this.toDisplayName(displayNameSource);
const modelId = details?.model || item.id;
if (router) {
const response = await ModelsService.listRouter();
return {
id: item.id,
name: displayName,
model: modelId,
description: details?.description,
capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)),
details: details?.details,
meta: item.meta ?? null,
parsedId: ModelsService.parseModelId(modelId),
aliases: item.aliases ?? [],
tags: item.tags ?? []
} satisfies ModelOption;
});
this.routerModels = response.data;
this.models = this.buildModelOptions(response);
this.models = models;
await this.fetchModalitiesForLoadedModels();
// WORKAROUND: In MODEL mode, /props returns modalities for the single model,
// but /v1/models doesn't include modalities. We bridge this gap here.
const serverProps = serverStore.props;
if (serverStore.isModelMode && this.models.length > 0 && serverProps?.modalities) {
const modalities: ModelModalities = {
vision: serverProps.modalities.vision ?? false,
audio: serverProps.modalities.audio ?? false,
video: serverProps.modalities.video ?? false
};
this.modelPropsCache.set(this.models[0].model, serverProps);
this.models = this.models.map((model, index) =>
index === 0 ? { ...model, modalities } : model
);
const visible = this.getVisibleModels();
if (visible.length === 1 && this.isModelLoaded(visible[0].model)) {
this.selectModelById(visible[0].id);
}
} else {
this.models = await this.fetchModelModeInternal();
}
} catch (error) {
this.models = [];
this.error = error instanceof Error ? error.message : 'Failed to load models';
throw error;
} finally {
this.loading = false;
}
}
/** Fetch models in MODEL mode (single model, standard OpenAI-compatible). */
private async fetchModelModeInternal(): Promise<ModelOption[]> {
const response = await ModelsService.list();
return this.buildModelOptions(response);
}
/**
* Fetch router models with full metadata (ROUTER mode only)
* This fetches the /models endpoint which returns status info for each model
* Build ModelOption[] from an API response.
* Both MODEL and ROUTER modes share the same mapping logic;
* they differ only in which endpoint is called.
*/
private buildModelOptions(
response: ApiModelListResponse | ApiRouterModelsListResponse
): ModelOption[] {
return response.data.map((item: ApiModelDataEntry, index: number) => {
const details = response.models?.[index];
const rawCapabilities = Array.isArray(details?.capabilities) ? details?.capabilities : [];
const displayNameSource =
details?.name && details.name.trim().length > 0 ? details.name : item.id;
const modelId = details?.model || item.id;
return {
id: item.id,
name: this.toDisplayName(displayNameSource),
model: modelId,
description: details?.description,
capabilities: rawCapabilities.filter((value: unknown): value is string => Boolean(value)),
details: details?.details,
meta: item.meta ?? null,
parsedId: ModelsService.parseModelId(modelId),
aliases: item.aliases ?? [],
tags: item.tags ?? []
};
});
}
/**
* Fetch router models with full metadata (ROUTER mode only).
* No-op in router mode fetch() already calls listRouter() internally.
* Kept for API compatibility (e.g. handleOpenChange dropdown open handler).
*/
async fetchRouterModels(): Promise<void> {
if (!isRouterMode()) return;
try {
const response = await ModelsService.listRouter();
this.routerModels = response.data;
await this.fetchModalitiesForLoadedModels();
const o = this.models.filter((option) => this.getModelProps(option.model)?.ui !== false);
if (o.length === 1 && this.isModelLoaded(o[0].model)) {
this.selectModelById(o[0].id);
const visible = this.getVisibleModels();
if (visible.length === 1 && this.isModelLoaded(visible[0].model)) {
this.selectModelById(visible[0].id);
}
} catch (error) {
console.warn('Failed to fetch router models:', error);
@ -364,10 +338,10 @@ class ModelsStore {
}
/**
* Fetch props for a specific model from /props endpoint
* Uses caching to avoid redundant requests
* Fetch props for a specific model from /props endpoint.
* Uses caching to avoid redundant requests.
*
* In ROUTER mode, this will only fetch props if the model is loaded,
* In ROUTER mode, this only fetches props if the model is loaded,
* since unloaded models return 400 from /props endpoint.
*
* @param modelId - Model identifier to fetch props for
@ -397,10 +371,7 @@ class ModelsStore {
}
}
/**
* Fetch modalities for all loaded models from /props endpoint
* This updates the modalities field in models array
*/
/** Fetch modalities for all loaded models from /props endpoint. */
async fetchModalitiesForLoadedModels(): Promise<void> {
const loadedModelIds = this.loadedModelIds;
if (loadedModelIds.length === 0) return;
@ -410,7 +381,6 @@ class ModelsStore {
try {
const results = await Promise.all(propsPromises);
// Update models with modalities
this.models = this.models.map((model) => {
const modelIndex = loadedModelIds.indexOf(model.model);
if (modelIndex === -1) return model;
@ -418,13 +388,7 @@ class ModelsStore {
const props = results[modelIndex];
if (!props?.modalities) return model;
const modalities: ModelModalities = {
vision: props.modalities.vision ?? false,
audio: props.modalities.audio ?? false,
video: props.modalities.video ?? false
};
return { ...model, modalities };
return { ...model, modalities: this.buildModalities(props.modalities) };
});
this.propsCacheVersion++;
@ -433,17 +397,38 @@ class ModelsStore {
}
}
/**
* Update modalities for a specific model.
* Called when a model is loaded or when we need fresh modality data.
*/
async updateModelModalities(modelId: string): Promise<void> {
const props = await this.fetchModelProps(modelId);
if (!props?.modalities) return;
this.models = this.models.map((model) =>
model.model === modelId
? { ...model, modalities: this.buildModalities(props.modalities!) }
: model
);
this.propsCacheVersion++;
}
/**
* Filter to models visible in the UI (ui !== false).
*/
private getVisibleModels(): ModelOption[] {
return this.models.filter((option) => this.getModelProps(option.model)?.ui !== false);
}
/**
* Gets the model name from the last assistant message in the active conversation.
* Iterates backward through messages to find the most recent message with a model.
* Used by both the chat page and settings page to maintain model consistency.
* @returns The model name or null if not found
*/
getModelFromLastAssistantResponse(): string | null {
const messages = conversationsStore.activeMessages;
if (!messages || messages.length === 0) return null;
// Iterate backward to find the last message with a model
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].model) {
return messages[i].model;
@ -456,22 +441,13 @@ class ModelsStore {
/**
* Auto-selects the model from the last assistant response if available and loaded.
* Returns true if a model was selected, false otherwise.
* This is used by the chat page to maintain model consistency across page navigation.
*/
async selectModelFromLastAssistantResponse(): Promise<boolean> {
const lastModel = this.getModelFromLastAssistantResponse();
if (!lastModel) return false;
// Skip if already selected
if (this.selectedModelName === lastModel) return false;
if (!lastModel || this.selectedModelName === lastModel) return false;
const matchingModel = this.models.find((option) => option.model === lastModel);
if (!matchingModel) return false;
if (!this.isModelLoaded(lastModel)) {
console.log('[modelsStore] last assistant model not loaded:', lastModel);
return false;
}
if (!matchingModel || !this.isModelLoaded(lastModel)) return false;
try {
await this.selectModelById(matchingModel.id);
@ -484,22 +460,17 @@ class ModelsStore {
}
/**
* Auto-selects the first available model if none is selected, and fetches its props.
* Auto-selects the first available model if none is selected.
* Prioritizes:
* 1. Model from active conversation's last assistant response (if loaded)
* 2. Model from active conversation's last assistant response (if not loaded)
* 3. First loaded model (not from active conversation)
* 4. First available model
* This is used to ensure default values are populated in settings pages.
*/
async ensureFirstModelSelected(): Promise<void> {
if (this.selectedModelName) return;
// Filter models that are visible in the UI
const availableModels = this.models.filter(
(option) => this.getModelProps(option.model)?.ui !== false
);
const availableModels = this.getVisibleModels();
if (availableModels.length === 0) return;
// Try to select model from last assistant response first
@ -515,7 +486,7 @@ class ModelsStore {
}
}
// Try to find a loaded model first
// Try a loaded model first
const loadedModel = availableModels.find((m) => this.isModelLoaded(m.model));
if (loadedModel) {
await this.selectModelById(loadedModel.id);
@ -524,34 +495,7 @@ class ModelsStore {
}
// Fall back to the first available model
const firstModel = availableModels[0];
await this.selectModelById(firstModel.id);
// Don't fetch props for unloaded models (will fail in ROUTER mode)
}
/**
* Update modalities for a specific model
* Called when a model is loaded or when we need fresh modality data
*/
async updateModelModalities(modelId: string): Promise<void> {
try {
const props = await this.fetchModelProps(modelId);
if (!props?.modalities) return;
const modalities: ModelModalities = {
vision: props.modalities.vision ?? false,
audio: props.modalities.audio ?? false,
video: props.modalities.video ?? false
};
this.models = this.models.map((model) =>
model.model === modelId ? { ...model, modalities } : model
);
this.propsCacheVersion++;
} catch (error) {
console.warn(`Failed to update modalities for model ${modelId}:`, error);
}
await this.selectModelById(availableModels[0].id);
}
/**
@ -562,9 +506,6 @@ class ModelsStore {
*
*/
/**
* Select a model for new conversations
*/
async selectModelById(modelId: string): Promise<void> {
if (!modelId || this.updating) return;
if (this.selectedModelId === modelId) return;
@ -584,8 +525,7 @@ class ModelsStore {
}
/**
* Select a model by its model name (used for syncing with conversation model)
* @param modelName - Model name to select (e.g., "ggml-org/GLM-4.7-Flash-GGUF")
* Select a model by its model name (used for syncing with conversation model).
*/
selectModelByName(modelName: string): void {
const option = this.models.find((model) => model.model === modelName);
@ -615,7 +555,7 @@ class ModelsStore {
/**
*
*
* Loading/Unloading Models
* Loading / Unloading Models
*
*
*/
@ -623,27 +563,18 @@ class ModelsStore {
/**
* WORKAROUND: Polling for model status after load/unload operations.
*
* Currently, the `/models/load` and `/models/unload` endpoints return success
* before the operation actually completes on the server. This means an immediate
* request to `/models` returns stale status (e.g., "loading" after load request,
* "loaded" after unload request).
* Currently, `/models/load` and `/models/unload` return success before
* the operation actually completes on the server.
*
* TODO: Remove this polling once llama-server properly waits for the operation
* to complete before returning success from `/load` and `/unload` endpoints.
* At that point, a single `fetchRouterModels()` call after the operation will
* be sufficient to get the correct status.
* TODO: Remove polling once llama-server properly waits for the operation
* to complete before returning success.
*/
/** Polling interval in ms for checking model status */
private static readonly STATUS_POLL_INTERVAL = 500;
/**
* Poll for expected model status after load/unload operation.
* Keeps polling indefinitely until the model reaches the expected status or fails.
*
* @param modelId - Model identifier to check
* @param expectedStatus - Expected status to wait for
* @throws Error if model reaches FAILED status
* Keeps polling until the model reaches the expected status or fails.
*/
private async pollForModelStatus(
modelId: string,
@ -654,9 +585,7 @@ class ModelsStore {
await this.fetchRouterModels();
const currentStatus = this.getModelStatus(modelId);
if (currentStatus === expectedStatus) {
return;
}
if (currentStatus === expectedStatus) return;
if (currentStatus === ServerModelStatus.FAILED) {
throw new Error(
@ -677,15 +606,8 @@ class ModelsStore {
}
}
/**
* Load a model (ROUTER mode)
* @param modelId - Model identifier to load
*/
async loadModel(modelId: string): Promise<void> {
if (this.isModelLoaded(modelId)) {
return;
}
if (this.isModelLoaded(modelId)) return;
if (this.modelLoadingStates.get(modelId)) return;
this.modelLoadingStates.set(modelId, true);
@ -694,7 +616,6 @@ class ModelsStore {
try {
await ModelsService.load(modelId);
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
await this.updateModelModalities(modelId);
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
} catch (error) {
@ -706,15 +627,8 @@ class ModelsStore {
}
}
/**
* Unload a model (ROUTER mode)
* @param modelId - Model identifier to unload
*/
async unloadModel(modelId: string): Promise<void> {
if (!this.isModelLoaded(modelId)) {
return;
}
if (!this.isModelLoaded(modelId)) return;
if (this.modelLoadingStates.get(modelId)) return;
this.modelLoadingStates.set(modelId, true);
@ -722,7 +636,6 @@ class ModelsStore {
try {
await ModelsService.unload(modelId);
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
} catch (error) {
@ -734,15 +647,8 @@ class ModelsStore {
}
}
/**
* Ensure a model is loaded before use
* @param modelId - Model identifier to ensure is loaded
*/
async ensureModelLoaded(modelId: string): Promise<void> {
if (this.isModelLoaded(modelId)) {
return;
}
if (this.isModelLoaded(modelId)) return;
await this.loadModel(modelId);
}
@ -779,11 +685,9 @@ class ModelsStore {
private loadFavoritesFromStorage(): Set<string> {
try {
const raw = localStorage.getItem(FAVORITE_MODELS_LOCALSTORAGE_KEY);
return raw ? new Set(JSON.parse(raw) as string[]) : new Set();
} catch {
toast.error('Failed to load favorite models from local storage');
return new Set();
}
}
@ -799,10 +703,19 @@ class ModelsStore {
private toDisplayName(id: string): string {
const segments = id.split(/\\|\//);
const candidate = segments.pop();
return candidate && candidate.trim().length > 0 ? candidate : id;
}
private buildModalities(
modalities: NonNullable<ApiLlamaCppServerProps['modalities']>
): ModelModalities {
return {
vision: modalities.vision ?? false,
audio: modalities.audio ?? false,
video: modalities.video ?? false
};
}
clear(): void {
this.models = [];
this.routerModels = [];

View file

@ -203,6 +203,7 @@ export interface ApiLlamaCppServerProps {
/** @deprecated Use {@link ui_settings} instead */
webui_settings?: Record<string, string | number | boolean>;
ui_settings?: Record<string, string | number | boolean>;
cors_proxy_enabled?: boolean;
}
export interface ApiChatCompletionRequest {

View file

@ -1,5 +1,5 @@
import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp';
import type { ToolSource } from '$lib/enums/tools';
import type { MCPConnectionPhase, MCPLogLevel, HealthCheckStatus } from '$lib/enums/mcp.enums';
import type { ToolSource } from '$lib/enums/tools.enums';
import type {
Client,
ClientCapabilities as SDKClientCapabilities,

View file

@ -12,17 +12,21 @@ export async function validateApiKey(fetch: typeof globalThis.fetch): Promise<vo
return;
}
const apiKey = config().apiKey;
// No API key configured — server doesn't require auth, skip the request entirely.
// The /props endpoint is only protected when the server has API keys configured,
// and in that case the client always has one set (from settings).
if (!apiKey) {
return;
}
try {
const apiKey = config().apiKey;
const headers: Record<string, string> = {
'Content-Type': 'application/json'
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`
};
if (apiKey) {
headers.Authorization = `Bearer ${apiKey}`;
}
const response = await fetch(`${base}/props`, { headers });
if (!response.ok) {

View file

@ -333,7 +333,8 @@ async function migrateConversation(convId: string): Promise<number> {
export async function runLegacyMigration(): Promise<void> {
if (!isMigrationNeeded()) return;
console.log('[Migration] Starting legacy message format migration...');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG)
console.log('[Migration] Starting legacy message format migration...');
try {
const conversations = await DatabaseService.getAllConversations();
@ -344,12 +345,14 @@ export async function runLegacyMigration(): Promise<void> {
totalMigrated += count;
}
if (totalMigrated > 0) {
console.log(
`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
);
} else {
console.log('[Migration] No legacy messages found, marking as done');
if (import.meta.env.DEV && import.meta.env.VITE_DEBUG) {
if (totalMigrated > 0) {
console.log(
`[Migration] Migrated ${totalMigrated} messages across ${conversations.length} conversations`
);
} else {
console.log('[Migration] No legacy messages found, marking as done');
}
}
markMigrationDone();

View file

@ -3,7 +3,6 @@
import { chatStore } from '$lib/stores/chat.svelte';
import { conversationsStore, isConversationsInitialized } from '$lib/stores/conversations.svelte';
import { modelsStore, modelOptions } from '$lib/stores/models.svelte';
import { isRouterMode } from '$lib/stores/server.svelte';
import { onMount } from 'svelte';
import { page } from '$app/state';
import { replaceState } from '$app/navigation';
@ -72,23 +71,13 @@
conversationsStore.clearActiveConversation();
chatStore.clearUIState();
if (
isRouterMode() &&
modelsStore.selectedModelName &&
!modelsStore.isModelLoaded(modelsStore.selectedModelName)
) {
modelsStore.clearSelection();
await modelsStore.fetch();
const first = modelOptions().find((m) => modelsStore.loadedModelIds.includes(m.model));
if (first) {
await modelsStore.selectModelById(first.id);
}
}
// Handle URL params only if we have ?q= or ?model= or ?new_chat=true
if (qParam !== null || modelParam !== null || newChatParam === 'true') {
await handleUrlParams();
}
await modelsStore.ensureFirstModelSelected();
});
</script>

View file

@ -7,11 +7,13 @@
import { untrack } from 'svelte';
import { onMount } from 'svelte';
import { fade } from 'svelte/transition';
import {
DesktopIconStrip,
DialogConversationTitleUpdate,
SidebarNavigation
} from '$lib/components/app';
import { conversationsStore } from '$lib/stores/conversations.svelte';
import * as Sidebar from '$lib/components/ui/sidebar/index.js';
import * as Tooltip from '$lib/components/ui/tooltip';
@ -30,26 +32,29 @@
import { conversations } from '$lib/stores/conversations.svelte';
let { children } = $props();
let alwaysShowSidebarOnDesktop = $derived(config().alwaysShowSidebarOnDesktop);
let isMobile = new IsMobile();
let isDesktop = $derived(!isMobile.current);
let sidebarOpen = $state(false);
let mounted = $state(false);
let innerHeight = $state<number | undefined>();
let chatSidebar:
| { activateSearchMode?: () => void; editActiveConversation?: () => void }
| {
activateSearchMode?: () => void;
editActiveConversation?: () => void;
}
| undefined = $state();
let titleUpdateDialogOpen = $state(false);
let titleUpdateCurrentTitle = $state('');
let titleUpdateNewTitle = $state('');
let titleUpdateResolve: ((value: boolean) => void) | null = null;
const panelNav = useSettingsNavigation();
function navigateToConversation(direction: -1 | 1) {
const allConvs = conversations();
if (allConvs.length === 0) return;
const currentId = page.params.id;
@ -61,6 +66,7 @@
}
const idx = allConvs.findIndex((c) => c.id === currentId);
if (idx === -1) return;
const targetIdx = idx + direction;
@ -75,38 +81,41 @@
// Global keyboard shortcuts
const { handleKeydown } = useKeyboardShortcuts({
editActiveConversation: () => chatSidebar?.editActiveConversation?.(),
navigateToPrevConversation: () => navigateToConversation(-1),
navigateToNextConversation: () => navigateToConversation(1)
});
function checkApiKey() {
const apiKey = config().apiKey;
if (
(page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') &&
page.status !== 401 &&
page.status !== 403
) {
const headers: Record<string, string> = {
'Content-Type': 'application/json'
};
if (apiKey && apiKey.trim() !== '') {
headers.Authorization = `Bearer ${apiKey.trim()}`;
}
fetch(`${base}/props`, { headers })
.then((response) => {
if (response.status === 401 || response.status === 403) {
window.location.reload();
}
})
.catch((e) => {
console.error('Error checking API key:', e);
});
// No API key configured — server doesn't require auth, no need to validate.
// This mirrors the early return in validateApiKey() to avoid redundant /props requests.
if (!apiKey || apiKey.trim() === '') {
return;
}
untrack(() => {
if (
(page.route.id === '/(chat)' || page.route.id === '/(chat)/chat/[id]') &&
page.status !== 401 &&
page.status !== 403
) {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey.trim()}`
};
fetch(`${base}/props`, { headers })
.then((response) => {
if (response.status === 401 || response.status === 403) {
window.location.reload();
}
})
.catch((e) => {
console.error('Error checking API key:', e);
});
}
});
}
function handleTitleUpdateCancel() {
@ -134,6 +143,7 @@
$effect(() => {
if (alwaysShowSidebarOnDesktop && isDesktop) {
sidebarOpen = true;
return;
}
});
@ -170,6 +180,7 @@
// Only fetch router models once when we have models loaded and in router mode
if (isRouter && modelsCount > 0 && !routerModelsFetched) {
routerModelsFetched = true;
untrack(() => {
modelsStore.fetchRouterModels();
});
@ -218,7 +229,6 @@
<Tooltip.Provider delayDuration={TOOLTIP_DELAY_DURATION}>
<ModeWatcher />
<Toaster richColors />
<DialogConversationTitleUpdate
@ -230,10 +240,10 @@
/>
<Sidebar.Provider bind:open={sidebarOpen}>
<div class="flex h-screen w-full" style:height="{innerHeight}px">
<Sidebar.Root variant="floating" class="h-full">
<SidebarNavigation bind:this={chatSidebar} />
</Sidebar.Root>
<div class="flex h-screen w-full">
<Sidebar.Root variant="floating" class="h-full"
><SidebarNavigation bind:this={chatSidebar} /></Sidebar.Root
>
{#if !(alwaysShowSidebarOnDesktop && isDesktop) && !(panelNav.isSettingsRoute && !isDesktop)}
{#if mounted}
@ -261,9 +271,9 @@
/>
{/if}
<Sidebar.Inset class="flex flex-1 flex-col overflow-hidden">
{@render children?.()}
</Sidebar.Inset>
<Sidebar.Inset class="flex flex-1 flex-col overflow-hidden"
>{@render children?.()}</Sidebar.Inset
>
</div>
</Sidebar.Provider>
</Tooltip.Provider>

View file

@ -8,6 +8,9 @@ $use-ttf: false;
$font-folder: 'katex-fonts';
// Import KaTeX SCSS with overridden variables
// Note: @import is deprecated but required because KaTeX uses @import internally
// The deprecation warnings are from KaTeX's code and cannot be avoided
@import 'katex/src/styles/katex.scss';
@use 'katex/src/styles/katex.scss' with (
$use-woff2: true,
$use-woff: false,
$use-ttf: false,
$font-folder: 'katex-fonts'
);

View file

@ -4,8 +4,9 @@ import TestWrapper from './components/TestWrapper.svelte';
describe('/+page.svelte', () => {
it('should render page without throwing', async () => {
// Basic smoke test - page should render without throwing errors
// API calls will fail in test environment but component should still mount
expect(() => render(TestWrapper)).not.toThrow();
// Basic smoke test - page should render without throwing errors.
// API calls are mocked in vitest-setup-client.ts.
await render(TestWrapper);
expect(true).toBe(true);
});
});

View file

@ -23,18 +23,6 @@ export default defineConfig({
minify: true
},
css: {
preprocessorOptions: {
scss: {
additionalData: `
$use-woff2: true;
$use-woff: false;
$use-ttf: false;
`
}
}
},
plugins: [tailwindcss(), sveltekit(), devtoolsJson(), llamaCppBuildPlugin()],
test: {

View file

@ -1,2 +1,80 @@
/// <reference types="@vitest/browser/matchers" />
/// <reference types="@vitest/browser/providers/playwright" />
import { beforeEach, vi } from 'vitest';
// Mock fetch for API calls during client tests.
// In test environment there is no backend server, so we intercept
// the specific endpoints the app uses and return valid mock data.
beforeEach(() => {
const originalFetch = globalThis.fetch;
vi.spyOn(globalThis, 'fetch').mockImplementation(
async (input: RequestInfo | URL, init?: RequestInit) => {
const url = typeof input === 'string' ? input : input instanceof URL ? input.href : input.url;
// Mock server props endpoint
if (url.includes('/server')) {
return new Response(
JSON.stringify({
mode: 'router',
version: 'test',
git_commit: 'test',
git_branch: 'test'
}),
{ status: 200, headers: { 'Content-Type': 'application/json' } }
);
}
// Mock models list endpoint
if (/\/v1\/models|\/models\b/.test(url)) {
return new Response(
JSON.stringify({
object: 'list',
data: [
{
id: 'test-model.gguf',
object: 'model',
owned_by: 'llamacpp',
created: 0,
in_cache: false,
path: 'models/test-model.gguf',
status: { value: 'unloaded' },
meta: {}
}
],
models: [
{
model: 'test-model.gguf',
name: 'Test Model',
details: {}
}
]
}),
{ status: 200, headers: { 'Content-Type': 'application/json' } }
);
}
// Mock /props endpoint (used for modalities)
if (url.includes('/props')) {
return new Response(
JSON.stringify({
default_generation_settings: { n_ctx: 2048 }
}),
{ status: 200, headers: { 'Content-Type': 'application/json' } }
);
}
// Mock /tools endpoint (used for built-in tools list)
if (url.includes('/tools')) {
return new Response(JSON.stringify([]), {
status: 200,
headers: { 'Content-Type': 'application/json' }
});
}
// Default: use real fetch
return originalFetch(input, init);
}
);
});

1
tools/ui/vitest.shims.d.ts vendored Normal file
View file

@ -0,0 +1 @@
/// <reference types="@vitest/browser-playwright" />