From 7d8aa31f1ff715e3dd8eab0cffd089a47f743061 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Tue, 10 Jun 2025 01:11:55 +0800
Subject: [PATCH] fixed embeddings, added new parameter to limit max embeddings context

---
 Makefile                         |   2 +
 examples/embedding/embedding.cpp | 338 +++++++++++++++++++++++++++++++
 expose.h                         |   1 +
 koboldcpp.py                     |  13 +-
 otherarch/embeddings_adapter.cpp |  13 +-
 src/llama-context.cpp            |   2 +-
 6 files changed, 360 insertions(+), 9 deletions(-)
 create mode 100644 examples/embedding/embedding.cpp

diff --git a/Makefile b/Makefile
index 9c71cb84f..828f511a2 100644
--- a/Makefile
+++ b/Makefile
@@ -699,6 +699,8 @@ mtmd-cli: tools/mtmd/mtmd-cli.cpp tools/mtmd/mtmd.cpp common/arg.cpp build-info.
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 mainvk: tools/main/main.cpp common/arg.cpp build-info.h ggml_v4_vulkan.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o ggml-vulkan.o $(OBJS_FULL) $(OBJS) lib/vulkan-1.lib
 	$(CXX) $(CXXFLAGS) -DGGML_USE_VULKAN -DSD_USE_VULKAN $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+embedding: examples/embedding/embedding.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o ggml-ops.o ggml-vec.o ggml-binops.o ggml-unops.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 ggml/src/ggml-vulkan-shaders.cpp:
 ifdef VULKAN_BUILD
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
new file mode 100644
index 000000000..ea6da69ba
--- /dev/null
+++ b/examples/embedding/embedding.cpp
@@ -0,0 +1,338 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <ctype.h>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
+    std::vector<std::string> lines;
+    size_t start = 0;
+    size_t end = s.find(separator);
+
+    while (end != std::string::npos) {
+        lines.push_back(s.substr(start, end - start));
+        start = end + separator.length();
+        end = s.find(separator, start);
+    }
+
+    lines.push_back(s.substr(start)); // Add the last part
+
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_memory_clear(llama_get_memory(ctx));
+
+    // run model
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
+    }
+
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        const float * embd = nullptr;
+        int embd_pos = 0;
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            // try to get token embeddings
+            embd = llama_get_embeddings_ith(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+        } else {
+            // try to get sequence embeddings - supported only when pooling_type is not NONE
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            embd_pos = batch.seq_id[i][0];
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        }
+
+        float * out = output + embd_pos * n_embd;
+        common_embd_normalize(embd, out, n_embd, embd_norm);
+    }
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
+        return 1;
+    }
+
+    common_init();
+
+    params.embedding = true;
+
+    // utilize the full context
+    if (params.n_batch < params.n_ctx) {
+        LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
+        params.n_batch = params.n_ctx;
+    }
+
+    // For non-causal models, batch size must be equal to ubatch size
+    params.n_ubatch = params.n_batch;
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // load the model
+    common_init_result llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init.model.get();
+    llama_context * ctx = llama_init.context.get();
+
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n", __func__);
+        return 1;
+    }
+
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_ctx_train = llama_model_n_ctx_train(model);
+    const int n_ctx = llama_n_ctx(ctx);
+
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+
+    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
+        return 1;
+    }
+
+    if (n_ctx > n_ctx_train) {
+        LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, n_ctx);
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
+
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = common_tokenize(ctx, prompt, true, true);
+        if (inp.size() > n_batch) {
+            LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
+                    __func__, (long long int) inp.size(), (long long int) n_batch);
+            return 1;
+        }
+        inputs.push_back(inp);
+    }
+
+    // check if the last token is SEP
+    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
+    for (auto & inp : inputs) {
+        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
+            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+            LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
+        }
+    }
+
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            LOG("\n\n");
+        }
+    }
+
+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
+    // count number of embeddings
+    int n_embd_count = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int k = 0; k < n_prompts; k++) {
+            n_embd_count += inputs[k].size();
+        }
+    } else {
+        n_embd_count = n_prompts;
+    }
+
+    // allocate output
+    const int n_embd = llama_model_n_embd(model);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int e = 0; // number of embeddings already stored
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + e * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
+            s = 0;
+            common_batch_clear(batch);
+        }
+
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
+    }
+
+    // final batch
+    float * out = emb + e * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
+
+    if (params.embd_out.empty()) {
+        LOG("\n");
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            for (int j = 0; j < n_embd_count; j++) {
+                LOG("embedding %d: ", j);
+                for (int i = 0; i < std::min(3, n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        LOG("%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        LOG("%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                LOG(" ... ");
+                for (int i = n_embd - 3; i < n_embd; i++) {
+                    if (params.embd_normalize == 0) {
+                        LOG("%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        LOG("%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                LOG("\n");
+            }
+        } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            const uint32_t n_cls_out = llama_model_n_cls_out(model);
+            std::vector<std::string> cls_out_labels;
+
+            for (uint32_t i = 0; i < n_cls_out; i++) {
+                const char * label = llama_model_cls_label(model, i);
+                const std::string label_i(label == nullptr ? "" : label);
+                cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
+            }
+
+            for (int j = 0; j < n_embd_count; j++) {
+                for (uint32_t i = 0; i < n_cls_out; i++) {
+                    // NOTE: if you change this log - update the tests in ci/run.sh
+                    if (n_cls_out == 1) {
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                    } else {
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                    }
+                }
+            }
+        } else {
+            // print the first part of the embeddings or for a single prompt, the full embedding
+            for (int j = 0; j < n_prompts; j++) {
+                LOG("embedding %d: ", j);
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        LOG("%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        LOG("%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                LOG("\n");
+            }
+
+            // print cosine similarity matrix
+            if (n_prompts > 1) {
+                LOG("\n");
+                LOG("cosine similarity matrix:\n\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    LOG("%6.6s ", prompts[i].c_str());
+                }
+                LOG("\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    for (int j = 0; j < n_prompts; j++) {
+                        float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                        LOG("%6.2f ", sim);
+                    }
+                    LOG("%1.10s", prompts[i].c_str());
+                    LOG("\n");
+                }
+            }
+        }
+    }
+
+    if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
+        const bool notArray = params.embd_out != "array";
+
+        LOG(notArray ? "{\n  \"object\": \"list\",\n  \"data\": [\n" : "[");
+        for (int j = 0;;) { // at least one iteration (one prompt)
+            if (notArray) LOG("    {\n      \"object\": \"embedding\",\n      \"index\": %d,\n      \"embedding\": ",j);
+            LOG("[");
+            for (int i = 0;;) { // at least one iteration (n_embd > 0)
+                LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                i++;
+                if (i < n_embd) LOG(","); else break;
+            }
+            LOG(notArray ? "]\n    }" : "]");
+            j++;
+            if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
+        }
+        LOG(notArray ? "\n  ]" : "]\n");
+
+        if (params.embd_out == "json+" && n_prompts > 1) {
+            LOG(",\n  \"cosineSimilarity\": [\n");
+            for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
+                LOG("    [");
+                for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
+                    float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    LOG("%6.2f", sim);
+                    j++;
+                    if (j < n_embd_count) LOG(", "); else break;
+                }
+                LOG(" ]");
+                i++;
+                if (i < n_embd_count) LOG(",\n"); else break;
+            }
+            LOG("\n  ]");
+        }
+
+        if (notArray) LOG("\n}\n");
+    }
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    // clean up
+    llama_batch_free(batch);
+    llama_backend_free();
+
+    return 0;
+}
\ No newline at end of file
diff --git a/expose.h b/expose.h
index aafc66fc3..9d03b30c4 100644
--- a/expose.h
+++ b/expose.h
@@ -257,6 +257,7 @@ struct embeddings_load_model_inputs
     const int gpulayers = 0;
     const bool flash_attention = false;
     const bool use_mmap = false;
+    const int embeddingsmaxctx = 0;
     const bool quiet = false;
     const int debugmode = 0;
 };
diff --git a/koboldcpp.py b/koboldcpp.py
index 678ccc4af..988d140b4 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -56,7 +56,7 @@ logit_bias_max = 512
 dry_seq_break_max = 128
 
 # global vars
-KcppVersion = "1.93.1"
+KcppVersion = "1.93.2"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False}
@@ -350,6 +350,7 @@ class embeddings_load_model_inputs(ctypes.Structure):
                 ("gpulayers", ctypes.c_int),
                 ("flash_attention", ctypes.c_bool),
                 ("use_mmap", ctypes.c_bool),
+                ("embeddingsmaxctx", ctypes.c_int),
                 ("quiet", ctypes.c_bool),
                 ("debugmode", ctypes.c_int)]
 
@@ -1742,6 +1743,7 @@ def embeddings_load_model(model_filename):
     inputs.flash_attention = False
     inputs.threads = args.threads
     inputs.use_mmap = args.usemmap
+    inputs.embeddingsmaxctx = args.embeddingsmaxctx
     inputs = set_backend_props(inputs)
     ret = handle.embeddings_load_model(inputs)
     return ret
@@ -4245,6 +4247,7 @@ def show_gui():
     ttsmaxlen_var = ctk.StringVar(value=str(default_ttsmaxlen))
 
     embeddings_model_var = ctk.StringVar()
+    embeddings_ctx_var = ctk.StringVar(value=str(""))
 
     admin_var = ctk.IntVar(value=0)
     admin_dir_var = ctk.StringVar()
@@ -4852,7 +4855,8 @@ def show_gui():
     makelabelentry(model_tab, "Draft Amount: ", draftamount_var, 13, 50,padx=100,singleline=True,tooltip="How many tokens to draft per chunk before verifying results")
     makelabelentry(model_tab, "Splits: ", draftgpusplit_str_vars, 13, 50,padx=210,singleline=True,tooltip="Distribution of draft model layers. Leave blank to follow main model's gpu split. Only works if multi-gpu (All) selected in main model.", labelpadx=160)
     makelabelentry(model_tab, "Layers: ", draftgpulayers_var, 13, 50,padx=320,singleline=True,tooltip="How many layers to GPU offload for the draft model", labelpadx=270)
-    makefileentry(model_tab, "Embeds Model:", "Select Embeddings Model File", embeddings_model_var, 15, width=280,singlerow=True, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select an embeddings GGUF model that can be used to generate embedding vectors.")
+    makefileentry(model_tab, "Embeds Model:", "Select Embeddings Model File", embeddings_model_var, 15, width=160,singlerow=True, filetypes=[("*.gguf","*.gguf")], tooltiptxt="Select an embeddings GGUF model that can be used to generate embedding vectors.")
+    makelabelentry(model_tab, "EmbdCtx: ", embeddings_ctx_var, 15, 50,padx=390,singleline=True,tooltip="If set above 0, limits max context for embedding model to save memory.", labelpadx=330)
     makefileentry(model_tab, "Preload Story:", "Select Preloaded Story File", preloadstory_var, 17,width=280,singlerow=True,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
     makefileentry(model_tab, "SaveData File:", "Select or Create New SaveData Database File", savedatafile_var, 19,width=280,filetypes=[("KoboldCpp SaveDB", "*.jsondb")],singlerow=True,dialog_type=1,tooltiptxt="Selecting a file will allow data to be loaded and saved persistently to this KoboldCpp server remotely. File is created if it does not exist.")
     makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 24, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
@@ -5207,6 +5211,9 @@ def show_gui():
         if embeddings_model_var.get() != "":
             args.embeddingsmodel = embeddings_model_var.get()
 
+        if embeddings_ctx_var.get() != "":
+            args.embeddingsmaxctx = (0 if embeddings_ctx_var.get()=="" else int(embeddings_ctx_var.get()))
+
         if tts_model_var.get() != "" and wavtokenizer_var.get() != "":
             args.ttsthreads = (0 if tts_threads_var.get()=="" else int(tts_threads_var.get()))
             args.ttsmodel = tts_model_var.get()
@@ -5401,6 +5408,7 @@ def show_gui():
         ttsmaxlen_var.set(str(dict["ttsmaxlen"]) if ("ttsmaxlen" in dict and dict["ttsmaxlen"]) else str(default_ttsmaxlen))
 
         embeddings_model_var.set(dict["embeddingsmodel"] if ("embeddingsmodel" in dict and dict["embeddingsmodel"]) else "")
+        embeddings_ctx_var.set(str(dict["embeddingsmaxctx"]) if ("embeddingsmaxctx" in dict and dict["embeddingsmaxctx"]) else "")
 
         admin_var.set(dict["admin"] if ("admin" in dict) else 0)
         admin_dir_var.set(dict["admindir"] if ("admindir" in dict and dict["admindir"]) else "")
@@ -7139,6 +7147,7 @@ if __name__ == '__main__':
 
     embeddingsparsergroup = parser.add_argument_group('Embeddings Model Commands')
     embeddingsparsergroup.add_argument("--embeddingsmodel", metavar=('[filename]'), help="Specify an embeddings model to be loaded for generating embedding vectors.", default="")
+    embeddingsparsergroup.add_argument("--embeddingsmaxctx", metavar=('[amount]'), help="Overrides the default maximum supported context of an embeddings model (defaults to trained context).", type=int, default=0)
 
     admingroup = parser.add_argument_group('Administration Commands')
     admingroup.add_argument("--admin", help="Enables admin mode, allowing you to unload and reload different configurations or models.", action='store_true')
diff --git a/otherarch/embeddings_adapter.cpp b/otherarch/embeddings_adapter.cpp
index 9bbc0f1f0..ac254e656 100644
--- a/otherarch/embeddings_adapter.cpp
+++ b/otherarch/embeddings_adapter.cpp
@@ -34,7 +34,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     const struct llama_model * model = llama_get_model(ctx);
 
@@ -48,7 +48,7 @@ static void batch_encode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // run model
-    if (llama_encode(ctx, batch) < 0) {
+    if (llama_decode(ctx, batch) < 0) {
         printf("%s : failed to process\n", __func__);
     }
 
@@ -125,7 +125,8 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs)
     llama_model * embeddingsmodel = llama_model_load_from_file(modelfile.c_str(), model_params);
 
     const int n_ctx_train = llama_model_n_ctx_train(embeddingsmodel);
-    max_batchsize = n_ctx_train;
+    max_batchsize = (inputs.embeddingsmaxctx>0?inputs.embeddingsmaxctx:n_ctx_train);
+    max_batchsize = (max_batchsize>n_ctx_train?n_ctx_train:max_batchsize);
     ctx_params.embeddings = true;
     ctx_params.n_ubatch = max_batchsize; //max size, must fit
     ctx_params.n_ctx = max_batchsize;
@@ -144,7 +145,7 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs)
 
     std::vector<llama_token> tmp = {1, 2, 3, 4};
     llama_kv_self_clear(embeddings_ctx);
-    auto er = llama_encode(embeddings_ctx, llama_batch_get_one(tmp.data(), tmp.size()));
+    auto er = llama_decode(embeddings_ctx, llama_batch_get_one(tmp.data(), tmp.size()));
     if(er!=0)
     {
         printf("\nEmbeddings Model Eval returned nonzero: %d\n",er);
@@ -259,7 +260,7 @@ embeddings_generation_outputs embeddingstype_generate(const embeddings_generatio
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + e * n_embd;
-            batch_encode(embeddings_ctx, batch, out, s, n_embd, embd_normalize);
+            batch_decode(embeddings_ctx, batch, out, s, n_embd, embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
             common_batch_clear(batch);
@@ -271,7 +272,7 @@ embeddings_generation_outputs embeddingstype_generate(const embeddings_generatio
 
     // final batch
     float * out = emb + e * n_embd;
-    batch_encode(embeddings_ctx, batch, out, s, n_embd, embd_normalize);
+    batch_decode(embeddings_ctx, batch, out, s, n_embd, embd_normalize);
 
     std::string outputarray = "[";
     for (int i = 0; i < n_embd; i++) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 3e8c96b63..ea9064204 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -890,7 +890,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
 int llama_context::decode(llama_batch & inp_batch) {
     if (!memory) {
-        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
+        //LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(inp_batch);
     }