From f0d46ef15717cd609a7b69cf6190edde64d466c8 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 12 May 2025 13:13:49 -0700 Subject: [PATCH 1/6] opencl: remove unnecessary assert for `add` (#13257) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 05a2f4e63..586946048 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4855,8 +4855,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor if (!any_on_device) { return false; } - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); func = ggml_cl_add; break; case GGML_OP_MUL: From cf0a43bb6490bd49344775abb22ba26f8047cb54 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Mon, 12 May 2025 15:31:37 -0700 Subject: [PATCH 2/6] llama-bench : add defrag-thold, check for invalid ranges (#13487) --- include/llama.h | 2 +- tools/llama-bench/README.md | 7 ++-- tools/llama-bench/llama-bench.cpp | 55 ++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/include/llama.h b/include/llama.h index abedebdb7..99e5fba24 100644 --- a/include/llama.h +++ b/include/llama.h @@ -345,7 +345,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default) + float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; diff --git a/tools/llama-bench/README.md b/tools/llama-bench/README.md index 4fb2a24e1..0479f81a3 100644 --- a/tools/llama-bench/README.md +++ b/tools/llama-bench/README.md @@ -43,12 +43,13 @@ test parameters: -ub, --ubatch-size (default: 512) -ctk, --cache-type-k (default: f16) -ctv, --cache-type-v (default: f16) - -t, --threads (default: 16) + -dt, --defrag-thold (default: -1) + -t, --threads (default: system dependent) -C, --cpu-mask (default: 0x0) --cpu-strict <0|1> (default: 0) --poll <0...100> (default: 50) -ngl, --n-gpu-layers (default: 99) - -rpc, --rpc (default: ) + -rpc, --rpc (default: none) -sm, --split-mode (default: layer) -mg, --main-gpu (default: 0) -nkvo, --no-kv-offload <0|1> (default: 0) @@ -62,7 +63,7 @@ test parameters: Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times. Ranges can be given as -'start-end' or 'start-end+step' or 'start-end*mult'. +'first-last' or 'first-last+step' or 'first-last*mult'. ``` llama-bench can perform three types of tests: diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index ca0d0aed5..9457e6815 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -211,6 +211,8 @@ static std::vector parse_int_range(const std::string & s) { for (int i = first; i <= last;) { result.push_back(i); + int prev_i = i; + if (op == '+') { i += step; } else if (op == '*') { @@ -218,6 +220,10 @@ static std::vector parse_int_range(const std::string & s) { } else { throw std::invalid_argument("invalid range format"); } + + if (i <= prev_i) { + throw std::invalid_argument("invalid range"); + } } search_start = match.suffix().first; } @@ -239,6 +245,7 @@ struct cmd_params { std::vector n_ubatch; std::vector type_k; std::vector type_v; + std::vector defrag_thold; std::vector n_threads; std::vector cpu_mask; std::vector cpu_strict; @@ -274,6 +281,7 @@ static const cmd_params cmd_params_defaults = { /* n_ubatch */ { 512 }, /* type_k */ { GGML_TYPE_F16 }, /* type_v */ { GGML_TYPE_F16 }, + /* defrag_thold */ { -1.0f }, /* n_threads */ { cpu_get_num_math() }, /* cpu_mask */ { "0x0" }, /* cpu_strict */ { false }, @@ -335,6 +343,8 @@ static void print_usage(int /* argc */, char ** argv) { join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); + printf(" -dt, --defrag-thold (default: %s)\n", + join(cmd_params_defaults.defrag_thold, ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -C, --cpu-mask (default: %s)\n", @@ -368,7 +378,7 @@ static void print_usage(int /* argc */, char ** argv) { printf( "Multiple values can be given for each parameter by separating them with ','\n" "or by specifying the parameter multiple times. Ranges can be given as\n" - "'start-end' or 'start-end+step' or 'start-end*mult'.\n"); + "'first-last' or 'first-last+step' or 'first-last*mult'.\n"); } static ggml_type ggml_type_from_name(const std::string & s) { @@ -519,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.type_v.insert(params.type_v.end(), types.begin(), types.end()); + } else if (arg == "-dt" || arg == "--defrag-thold") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end()); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { invalid_param = true; @@ -825,6 +842,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } + if (params.defrag_thold.empty()) { + params.defrag_thold = cmd_params_defaults.defrag_thold; + } if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } @@ -883,6 +903,7 @@ struct cmd_params_instance { int n_ubatch; ggml_type type_k; ggml_type type_v; + float defrag_thold; int n_threads; std::string cpu_mask; bool cpu_strict; @@ -959,15 +980,16 @@ struct cmd_params_instance { llama_context_params to_llama_cparams() const { llama_context_params cparams = llama_context_default_params(); - cparams.n_ctx = n_prompt + n_gen + n_depth; - cparams.n_batch = n_batch; - cparams.n_ubatch = n_ubatch; - cparams.type_k = type_k; - cparams.type_v = type_v; - cparams.offload_kqv = !no_kv_offload; - cparams.flash_attn = flash_attn; - cparams.embeddings = embeddings; - cparams.op_offload = !no_op_offload; + cparams.n_ctx = n_prompt + n_gen + n_depth; + cparams.n_batch = n_batch; + cparams.n_ubatch = n_ubatch; + cparams.type_k = type_k; + cparams.type_v = type_v; + cparams.defrag_thold = defrag_thold; + cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; + cparams.embeddings = embeddings; + cparams.op_offload = !no_op_offload; return cparams; } @@ -992,6 +1014,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & nub : params.n_ubatch) for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) + for (const auto & defrag_thold : params.defrag_thold) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) @@ -1012,6 +1035,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .n_ubatch = */ nub, /* .type_k = */ tk, /* .type_v = */ tv, + /* .defrag_thold = */ defrag_thold, /* .n_threads = */ nt, /* .cpu_mask = */ cm, /* .cpu_strict = */ cs, @@ -1044,6 +1068,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .n_ubatch = */ nub, /* .type_k = */ tk, /* .type_v = */ tv, + /* .defrag_thold = */ defrag_thold, /* .n_threads = */ nt, /* .cpu_mask = */ cm, /* .cpu_strict = */ cs, @@ -1076,6 +1101,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .n_ubatch = */ nub, /* .type_k = */ tk, /* .type_v = */ tv, + /* .defrag_thold = */ defrag_thold, /* .n_threads = */ nt, /* .cpu_mask = */ cm, /* .cpu_strict = */ cs, @@ -1117,6 +1143,7 @@ struct test { int poll; ggml_type type_k; ggml_type type_v; + float defrag_thold; int n_gpu_layers; llama_split_mode split_mode; int main_gpu; @@ -1151,6 +1178,7 @@ struct test { poll = inst.poll; type_k = inst.type_k; type_v = inst.type_v; + defrag_thold = inst.defrag_thold; n_gpu_layers = inst.n_gpu_layers; split_mode = inst.split_mode; main_gpu = inst.main_gpu; @@ -1206,6 +1234,7 @@ struct test { "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", + "defrag_thold", "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", }; @@ -1225,7 +1254,7 @@ struct test { field == "use_mmap" || field == "embeddings") { return BOOL; } - if (field == "avg_ts" || field == "stddev_ts") { + if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") { return FLOAT; } return STRING; @@ -1292,6 +1321,7 @@ struct test { std::to_string(flash_attn), tensor_split_str, tensor_buft_overrides_str, + std::to_string(defrag_thold), std::to_string(use_mmap), std::to_string(embeddings), std::to_string(no_op_offload), @@ -1558,6 +1588,9 @@ struct markdown_printer : public printer { if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) { fields.emplace_back("type_v"); } + if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) { + fields.emplace_back("defrag_thold"); + } if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) { fields.emplace_back("main_gpu"); } From 1e2809bc4b5d8db2a9ed12ac872eca832c53f5fd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 May 2025 14:01:45 +0300 Subject: [PATCH 3/6] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 1f7c650c2..ddd884d37 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b59bddafe278877dfa22a80e53a637513862babb +9b048bb72b811f50b0c30d9e5c84d6ff9f4bf005 From d590cd4c244e5f260c42c290b83a358b9d86d763 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 07:12:01 -0600 Subject: [PATCH 4/6] model : Granite MoE shared (#13269) * feat: Add GGUF conversion for granitemoeshared Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * feat: hparam and arch plumbing for granitemoeshared Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Split MoE fused tensors for shared experts in conversion Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * feat: First WIP cut at model arch in cpp The hparam and architecture plumbing should be correct, but the implementation of the shared experts seems to still be broken. Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Cleaner (maybe more correct?) splitting for gate/up Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Fix the input to the shared experts I had misread that the shared experts take the inputs _before_ the standard MoE layer and was feeding the output of the MoE to the shared experts. Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Avoid architecture-specific checks for Granite MoE Shared This is a cleaner way that will allow more flexibility in architecture strings going forward. Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * refactor: Split granite architectures out of llm_build_llama This helps de-clutter the llama-family graph construction and allows granite to diverge further (in preparation for Granite 4). NOTE: I removed the granite scale factors from llm_build_deci because they appear to only be there as copy-paste from llm_build_llama. The HF config does not seem to set those values: https://huggingface.co/Deci/DeciLM-7B/blob/main/config.json Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Fix compiler warning about uninitialized inp_pos This should not have been reachable, but it warns on some compliers Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Consoladate GraniteMoEShared into GraniteMoE for conversion Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart * fix: Consolidate GraniteMoEShared into GraniteMoE on the c++ side Branch: GraniteMoEShared Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart --- convert_hf_to_gguf.py | 22 ++- gguf-py/gguf/constants.py | 3 + gguf-py/gguf/tensor_mapping.py | 1 + src/llama-arch.cpp | 3 + src/llama-model.cpp | 241 ++++++++++++++++++++++++++++----- 5 files changed, 235 insertions(+), 35 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a34ba2988..68b5e8799 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5746,11 +5746,20 @@ class GraniteModel(LlamaModel): logger.info("gguf: (granite) logits_scale = %s", logits_scale) -@ModelBase.register("GraniteMoeForCausalLM") +@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM") class GraniteMoeModel(GraniteModel): """Conversion for IBM's GraniteMoeForCausalLM""" model_arch = gguf.MODEL_ARCH.GRANITE_MOE + def set_gguf_parameters(self): + """GraniteMoeShared uses GraniteMoe parameters plus the following: + - shared_intermediate_size + """ + super().set_gguf_parameters() + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: """In modeling_granitemoe, the JetMoe implementation of parallel experts is used. This essentially merges w1 and w3 into a single tensor with 2x @@ -5761,12 +5770,21 @@ class GraniteMoeModel(GraniteModel): if name.endswith("block_sparse_moe.input_linear.weight"): ffn_dim = self.hparams["intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" - gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), ] + if name.endswith("shared_mlp.input_linear.weight"): + ffn_dim = self.hparams["shared_intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + gate, up = data_torch.split(ffn_dim, dim=-2) + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + return super().modify_tensors(data_torch, name, bid) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0e6226b90..21af0a9a2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1905,6 +1905,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, ], MODEL_ARCH.CHAMELEON: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index ecf21b2b4..2629b3c1a 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -428,6 +428,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe ), MODEL_TENSOR.ATTN_Q_NORM: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f2bc8ca76..abf436ada 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1481,6 +1481,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3a4e72a36..f652f4b86 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1389,6 +1389,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Add additional layer/vocab/etc checks here for other model sizes default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -1772,6 +1775,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } } } } break; @@ -4385,10 +4395,13 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } - if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { @@ -4598,11 +4611,6 @@ struct llm_build_llama : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -4674,11 +4682,6 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "ffn_moe_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4701,11 +4704,6 @@ struct llm_build_llama : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4816,11 +4814,6 @@ struct llm_build_deci : public llm_graph_context { continue; } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - // modified to support attention-free layer of Llama-3_1-Nemotron-51B ggml_tensor * ffn_inp = cur; if (n_head > 0) { @@ -4844,11 +4837,6 @@ struct llm_build_deci : public llm_graph_context { cb(cur, "ffn_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4871,11 +4859,6 @@ struct llm_build_deci : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -12214,6 +12197,195 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; + +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + + if (!inp_pos) { + inp_pos = build_inp_pos(); + } + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm @@ -12921,8 +13093,6 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA4: case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: { llm = std::make_unique(*this, params, gf); } break; @@ -13153,6 +13323,11 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_CHAMELEON: { llm = std::make_unique(*this, params, gf); From bf7937112058f2815fc3825a9ff7b536ecafa3bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 13 May 2025 15:31:12 +0200 Subject: [PATCH 5/6] scripts : support arbitrary input file formats in compare-llama-bench.py (#13455) --- scripts/compare-llama-bench.py | 452 ++++++++++++++++++++++++--------- 1 file changed, 330 insertions(+), 122 deletions(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index c32b449f7..fc93bf62a 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -7,6 +7,10 @@ import sys import os from glob import glob import sqlite3 +import json +import csv +from typing import Optional, Union +from collections.abc import Iterator, Sequence try: import git @@ -17,6 +21,28 @@ except ImportError as e: logger = logging.getLogger("compare-llama-bench") +# All llama-bench SQL fields +DB_FIELDS = [ + "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", + "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", + "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", + "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", + "defrag_thold", + "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", + "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", +] + +DB_TYPES = [ + "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", + "REAL", + "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", +] +assert len(DB_FIELDS) == len(DB_TYPES) + # Properties by which to differentiate results per commit: KEY_PROPERTIES = [ "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type", @@ -42,7 +68,7 @@ DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default. GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables. MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"} -DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux): +DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux): $ git checkout master $ make clean && make llama-bench @@ -70,12 +96,13 @@ help_c = ( ) parser.add_argument("-c", "--compare", help=help_c) help_i = ( - "Input SQLite file for comparing commits. " + "JSON/JSONL/SQLite/CSV files for comparing commits. " + "Specify multiple times to use multiple input files (JSON/CSV only). " "Defaults to 'llama-bench.sqlite' in the current working directory. " "If no such file is found and there is exactly one .sqlite file in the current directory, " "that file is instead used as input." ) -parser.add_argument("-i", "--input", help=help_i) +parser.add_argument("-i", "--input", action="append", help=help_i) help_o = ( "Output format for the table. " "Defaults to 'pipe' (GitHub compatible). " @@ -110,119 +137,321 @@ if unknown_args: sys.exit(1) input_file = known_args.input -if input_file is None and os.path.exists("./llama-bench.sqlite"): - input_file = "llama-bench.sqlite" -if input_file is None: +if not input_file and os.path.exists("./llama-bench.sqlite"): + input_file = ["llama-bench.sqlite"] +if not input_file: sqlite_files = glob("*.sqlite") if len(sqlite_files) == 1: - input_file = sqlite_files[0] + input_file = sqlite_files -if input_file is None: +if not input_file: logger.error("Cannot find a suitable input file, please provide one.\n") parser.print_help() sys.exit(1) -connection = sqlite3.connect(input_file) -cursor = connection.cursor() -build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] -build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] +class LlamaBenchData: + repo: Optional[git.Repo] + build_len_min: int + build_len_max: int + build_len: int = 8 + builds: list[str] = [] + check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"]) -if build_len_min != build_len_max: - logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " - "Try purging the the database of old commits.") - cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});") + def __init__(self): + try: + self.repo = git.Repo(".", search_parent_directories=True) + except git.InvalidGitRepositoryError: + self.repo = None -build_len: int = build_len_min + def _builds_init(self): + self.build_len = self.build_len_min -builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() -builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] - -if not builds: - raise RuntimeError(f"{input_file} does not contain any builds.") - -try: - repo = git.Repo(".", search_parent_directories=True) -except git.InvalidGitRepositoryError: - repo = None - - -def find_parent_in_data(commit: git.Commit): - """Helper function to find the most recent parent measured in number of commits for which there is data.""" - heap: list[tuple[int, git.Commit]] = [(0, commit)] - seen_hexsha8 = set() - while heap: - depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:build_len] - if current_hexsha8 in builds: - return current_hexsha8 - for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:build_len] - if parent_hexsha8 not in seen_hexsha8: - seen_hexsha8.add(parent_hexsha8) - heapq.heappush(heap, (depth + 1, parent)) - return None - - -def get_all_parent_hexsha8s(commit: git.Commit): - """Helper function to recursively get hexsha8 values for all parents of a commit.""" - unvisited = [commit] - visited = [] - - while unvisited: - current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:build_len]) - for parent in current_commit.parents: - if parent.hexsha[:build_len] not in visited: - unvisited.append(parent) - - return visited - - -def get_commit_name(hexsha8: str): - """Helper function to find a human-readable name for a commit if possible.""" - if repo is None: - return hexsha8 - for h in repo.heads: - if h.commit.hexsha[:build_len] == hexsha8: - return h.name - for t in repo.tags: - if t.commit.hexsha[:build_len] == hexsha8: - return t.name - return hexsha8 - - -def get_commit_hexsha8(name: str): - """Helper function to search for a commit given a human-readable name.""" - if repo is None: + def _check_keys(self, keys: set) -> Optional[set]: + """Private helper method that checks against required data keys and returns missing ones.""" + if not keys >= self.check_keys: + return self.check_keys - keys return None - for h in repo.heads: - if h.name == name: - return h.commit.hexsha[:build_len] - for t in repo.tags: - if t.name == name: - return t.commit.hexsha[:build_len] - for c in repo.iter_commits("--all"): - if c.hexsha[:build_len] == name[:build_len]: - return c.hexsha[:build_len] - return None + + def find_parent_in_data(self, commit: git.Commit) -> Optional[str]: + """Helper method to find the most recent parent measured in number of commits for which there is data.""" + heap: list[tuple[int, git.Commit]] = [(0, commit)] + seen_hexsha8 = set() + while heap: + depth, current_commit = heapq.heappop(heap) + current_hexsha8 = commit.hexsha[:self.build_len] + if current_hexsha8 in self.builds: + return current_hexsha8 + for parent in commit.parents: + parent_hexsha8 = parent.hexsha[:self.build_len] + if parent_hexsha8 not in seen_hexsha8: + seen_hexsha8.add(parent_hexsha8) + heapq.heappush(heap, (depth + 1, parent)) + return None + + def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]: + """Helper method to recursively get hexsha8 values for all parents of a commit.""" + unvisited = [commit] + visited = [] + + while unvisited: + current_commit = unvisited.pop(0) + visited.append(current_commit.hexsha[:self.build_len]) + for parent in current_commit.parents: + if parent.hexsha[:self.build_len] not in visited: + unvisited.append(parent) + + return visited + + def get_commit_name(self, hexsha8: str) -> str: + """Helper method to find a human-readable name for a commit if possible.""" + if self.repo is None: + return hexsha8 + for h in self.repo.heads: + if h.commit.hexsha[:self.build_len] == hexsha8: + return h.name + for t in self.repo.tags: + if t.commit.hexsha[:self.build_len] == hexsha8: + return t.name + return hexsha8 + + def get_commit_hexsha8(self, name: str) -> Optional[str]: + """Helper method to search for a commit given a human-readable name.""" + if self.repo is None: + return None + for h in self.repo.heads: + if h.name == name: + return h.commit.hexsha[:self.build_len] + for t in self.repo.tags: + if t.name == name: + return t.commit.hexsha[:self.build_len] + for c in self.repo.iter_commits("--all"): + if c.hexsha[:self.build_len] == name[:self.build_len]: + return c.hexsha[:self.build_len] + return None + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + """Helper method that gets rows of (build_commit, test_time) sorted by the latter.""" + return [] + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + """ + Helper method that gets table rows for some list of properties. + Rows are created by combining those where all provided properties are equal. + The resulting rows are then grouped by the provided properties and the t/s values are averaged. + The returned rows are unique in terms of property combinations. + """ + return [] + + +class LlamaBenchDataSQLite3(LlamaBenchData): + connection: sqlite3.Connection + cursor: sqlite3.Cursor + + def __init__(self): + super().__init__() + self.connection = sqlite3.connect(":memory:") + self.cursor = self.connection.cursor() + self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});") + + def _builds_init(self): + if self.connection: + self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] + self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] + + if self.build_len_min != self.build_len_max: + logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " + "Try purging the the database of old commits.") + self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});") + + builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() + self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] + super()._builds_init() + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + data = self.cursor.execute( + "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() + return reversed(data) if reverse else data + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + select_string = ", ".join( + [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) + equal_string = " AND ".join( + [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ + f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] + ) + group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) + query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " + f"GROUP BY {group_order_string} ORDER BY {group_order_string};") + return self.cursor.execute(query).fetchall() + + +class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + self.connection.close() + self.connection = sqlite3.connect(data_file) + self.cursor = self.connection.cursor() + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + connection = sqlite3.connect(data_file) + cursor = connection.cursor() + + try: + if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0: + raise sqlite3.DatabaseError("The provided input file does not exist or is empty.") + except sqlite3.DatabaseError as e: + logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e) + cursor = None + + connection.close() + return True if cursor else False + + +class LlamaBenchDataJSONL(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + with open(data_file, "r", encoding="utf-8") as fp: + for i, line in enumerate(fp): + parsed = json.loads(line) + + for k in parsed.keys() - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(parsed.keys())): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for line in fp: + json.loads(line) + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataJSON(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + parsed = json.load(fp) + + for i, entry in enumerate(parsed): + for k in entry.keys() - set(DB_FIELDS): + del entry[k] + + if (missing_keys := self._check_keys(entry.keys())): + raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + json.load(fp) + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataCSV(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + for i, parsed in enumerate(csv.DictReader(fp)): + keys = set(parsed.keys()) + + for k in keys - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(keys)): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for parsed in csv.DictReader(fp): + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e) + return False + + return True + + +bench_data = None +if len(input_file) == 1: + if LlamaBenchDataSQLite3File.valid_format(input_file[0]): + bench_data = LlamaBenchDataSQLite3File(input_file[0]) + elif LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataJSONL.valid_format(input_file[0]): + bench_data = LlamaBenchDataJSONL(input_file[0]) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) +else: + if LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) + +if not bench_data: + raise RuntimeError("No valid (or some invalid) input files found.") + +if not bench_data.builds: + raise RuntimeError(f"{input_file} does not contain any builds.") hexsha8_baseline = name_baseline = None # If the user specified a baseline, try to find a commit for it: if known_args.baseline is not None: - if known_args.baseline in builds: + if known_args.baseline in bench_data.builds: hexsha8_baseline = known_args.baseline if hexsha8_baseline is None: - hexsha8_baseline = get_commit_hexsha8(known_args.baseline) + hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline) name_baseline = known_args.baseline if hexsha8_baseline is None: logger.error(f"cannot find data for baseline={known_args.baseline}.") sys.exit(1) # Otherwise, search for the most recent parent of master for which there is data: -elif repo is not None: - hexsha8_baseline = find_parent_in_data(repo.heads.master.commit) +elif bench_data.repo is not None: + hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit) if hexsha8_baseline is None: logger.error("No baseline was provided and did not find data for any master branch commits.\n") @@ -235,27 +464,25 @@ else: sys.exit(1) -name_baseline = get_commit_name(hexsha8_baseline) +name_baseline = bench_data.get_commit_name(hexsha8_baseline) hexsha8_compare = name_compare = None # If the user has specified a compare value, try to find a corresponding commit: if known_args.compare is not None: - if known_args.compare in builds: + if known_args.compare in bench_data.builds: hexsha8_compare = known_args.compare if hexsha8_compare is None: - hexsha8_compare = get_commit_hexsha8(known_args.compare) + hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare) name_compare = known_args.compare if hexsha8_compare is None: logger.error(f"cannot find data for compare={known_args.compare}.") sys.exit(1) # Otherwise, search for the commit for llama-bench was most recently run # and that is not a parent of master: -elif repo is not None: - hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit) - builds_timestamp = cursor.execute( - "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() - for (hexsha8, _) in reversed(builds_timestamp): +elif bench_data.repo is not None: + hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit) + for (hexsha8, _) in bench_data.builds_timestamp(reverse=True): if hexsha8 not in hexsha8s_master: hexsha8_compare = hexsha8 break @@ -270,26 +497,7 @@ else: parser.print_help() sys.exit(1) -name_compare = get_commit_name(hexsha8_compare) - - -def get_rows(properties): - """ - Helper function that gets table rows for some list of properties. - Rows are created by combining those where all provided properties are equal. - The resulting rows are then grouped by the provided properties and the t/s values are averaged. - The returned rows are unique in terms of property combinations. - """ - select_string = ", ".join( - [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) - equal_string = " AND ".join( - [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ - f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] - ) - group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) - query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " - f"GROUP BY {group_order_string} ORDER BY {group_order_string};") - return cursor.execute(query).fetchall() +name_compare = bench_data.get_commit_name(hexsha8_compare) # If the user provided columns to group the results by, use them: @@ -303,10 +511,10 @@ if known_args.show is not None: logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}") parser.print_usage() sys.exit(1) - rows_show = get_rows(show) + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) # Otherwise, select those columns where the values are not all the same: else: - rows_full = get_rows(KEY_PROPERTIES) + rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare) properties_different = [] for i, kp_i in enumerate(KEY_PROPERTIES): if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]: @@ -336,7 +544,7 @@ else: show.remove(prop) except ValueError: pass - rows_show = get_rows(show) + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) if not rows_show: logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n") From b4726345aca49e2ad62d615e6e370b3dbad6434f Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 13 May 2025 15:33:58 +0200 Subject: [PATCH 6/6] =?UTF-8?q?mtmd=20:=20remove=20libllava,=20remove=20cl?= =?UTF-8?q?ip-quantize-cli=20(=E2=9A=A0=EF=B8=8F=20breaking=20change)=20(#?= =?UTF-8?q?13460)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mtmd : remove libllava, remove clip-quantize-cli * rm clip_model_quantize --- tools/mtmd/CMakeLists.txt | 35 -- tools/mtmd/README-quantize.md | 44 -- tools/mtmd/README.md | 7 +- tools/mtmd/android/adb_run.sh | 53 -- tools/mtmd/android/build_64.sh | 8 - tools/mtmd/clip-quantize-cli.cpp | 59 -- tools/mtmd/clip.cpp | 135 ---- .../convert_image_encoder_to_gguf.py | 0 .../glmedge-convert-image-encoder-to-gguf.py | 0 .../{ => legacy-models}/glmedge-surgery.py | 0 .../mtmd/{ => legacy-models}/llava_surgery.py | 0 .../{ => legacy-models}/llava_surgery_v2.py | 0 .../minicpmv-convert-image-encoder-to-gguf.py | 0 .../{ => legacy-models}/minicpmv-surgery.py | 0 tools/mtmd/llava.cpp | 591 ------------------ tools/mtmd/llava.h | 49 -- 16 files changed, 4 insertions(+), 977 deletions(-) delete mode 100644 tools/mtmd/README-quantize.md delete mode 100755 tools/mtmd/android/adb_run.sh delete mode 100755 tools/mtmd/android/build_64.sh delete mode 100644 tools/mtmd/clip-quantize-cli.cpp rename tools/mtmd/{ => legacy-models}/convert_image_encoder_to_gguf.py (100%) rename tools/mtmd/{ => legacy-models}/glmedge-convert-image-encoder-to-gguf.py (100%) rename tools/mtmd/{ => legacy-models}/glmedge-surgery.py (100%) rename tools/mtmd/{ => legacy-models}/llava_surgery.py (100%) rename tools/mtmd/{ => legacy-models}/llava_surgery_v2.py (100%) rename tools/mtmd/{ => legacy-models}/minicpmv-convert-image-encoder-to-gguf.py (100%) rename tools/mtmd/{ => legacy-models}/minicpmv-surgery.py (100%) delete mode 100644 tools/mtmd/llava.cpp delete mode 100644 tools/mtmd/llava.h diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index dfafa9cf8..e7ba23587 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -1,29 +1,3 @@ -# llava (legacy) - -add_library(llava OBJECT - llava.cpp - llava.h - clip.cpp - clip.h - ) - -target_link_libraries(llava PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - -target_include_directories(llava PUBLIC .) -target_include_directories(llava PUBLIC ../..) -target_include_directories(llava PUBLIC ../../common) - -target_compile_features(llava PRIVATE cxx_std_17) - -add_library(llava_static STATIC $) -if (BUILD_SHARED_LIBS) - set_target_properties(llava PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(llava PRIVATE LLAMA_SHARED LLAMA_BUILD) - add_library(llava_shared SHARED $) - target_link_libraries(llava_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - install(TARGETS llava_shared LIBRARY) -endif() - # mtmd add_library(mtmd OBJECT @@ -53,12 +27,10 @@ if (BUILD_SHARED_LIBS) endif() if (NOT MSVC) - target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h endif() if(TARGET BUILD_INFO) - add_dependencies(llava BUILD_INFO) add_dependencies(mtmd BUILD_INFO) endif() @@ -73,10 +45,3 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) - -set(TARGET llama-llava-clip-quantize-cli) -add_executable(${TARGET} clip-quantize-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/mtmd/README-quantize.md b/tools/mtmd/README-quantize.md deleted file mode 100644 index b931513ab..000000000 --- a/tools/mtmd/README-quantize.md +++ /dev/null @@ -1,44 +0,0 @@ -# Quantizing CLIP Visual Projector - -This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance. - -## Usage - -To quantize a CLIP visual projector model, use the following command: - -```sh -./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf -``` - -After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc). - -### Arguments - -- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format. -- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved. -- ``: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`. - -### Quantization Types - -The following quantization types are supported, based on the `enum ggml_type` definition: - -- `2` - `q4_0`: 4-bit quantization with a single scale value. -- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block. -- `6` - `q5_0`: 5-bit quantization with a single scale value. -- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block. -- `8` - `q8_0`: 8-bit quantization with a single scale value. - -### Example - -To quantize a model using the `q4_0` quantization type, you would run: - -```sh -./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2 -``` - -This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method. - -## Notes - -- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements. -- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments. diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md index ab258ea17..ef31d1957 100644 --- a/tools/mtmd/README.md +++ b/tools/mtmd/README.md @@ -41,8 +41,8 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta Multimodal projector (`mmproj`) files are specific to each model architecture. -For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file: -- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support +For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag to get the `mmproj` file: +- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) ; See the guide [here](../../docs/multimodal/gemma3.md) - Note: 1B variant does not have vision support - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB)) - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB)) - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint @@ -52,6 +52,8 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla For older models, please refer to the relevant guide for instructions on how to obtain or create them: +NOTE: conversion scripts are located under `tools/mtmd/legacy-models` + - [LLaVA](../../docs/multimodal/llava.md) - [MobileVLM](../../docs/multimodal/MobileVLM.md) - [GLM-Edge](../../docs/multimodal/glmedge.md) @@ -59,4 +61,3 @@ For older models, please refer to the relevant guide for instructions on how to - [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md) - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md) - [IBM Granite Vision](../../docs/multimodal/granitevision.md) -- [Google Gemma 3](../../docs/multimodal/gemma3.md) diff --git a/tools/mtmd/android/adb_run.sh b/tools/mtmd/android/adb_run.sh deleted file mode 100755 index a24d6787d..000000000 --- a/tools/mtmd/android/adb_run.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -model_dir="/Users/cxt/model/llm/mobileVLM/MobileVLM-1.7B_processed" -projector_name="mmproj-model-f16.gguf" -llama_name="ggml-model-q4_k.gguf" -img_dir="/Users/cxt/model/llm" -img_name="demo.jpg" -prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" -# img_name="cat.jpeg" -# prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" - -program_dir="build_64/bin" -binName="llama-mtmd-cli" -n_threads=4 - - -deviceDir="/data/local/tmp" -saveDir="output" -if [ ! -d ${saveDir} ]; then - mkdir ${saveDir} -fi - - -function android_run() { - # # copy resource into device - # adb push ${model_dir}/${projector_name} ${deviceDir}/${projector_name} - # adb push ${model_dir}/${llama_name} ${deviceDir}/${llama_name} - adb push ${img_dir}/${img_name} ${deviceDir}/${img_name} - # copy program into device - adb push ${program_dir}/${binName} ${deviceDir}/${binName} - adb shell "chmod 0777 ${deviceDir}/${binName}" - - # run - adb shell "echo cd ${deviceDir} ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - > ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt" - adb shell "cd ${deviceDir}; pwd; ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - >> ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt 2>&1" - adb pull ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt ${saveDir} -} - -android_run - -echo "android_run is Done!" diff --git a/tools/mtmd/android/build_64.sh b/tools/mtmd/android/build_64.sh deleted file mode 100755 index 71b6fd3f7..000000000 --- a/tools/mtmd/android/build_64.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -cmake ../../../../ \ --DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ --DCMAKE_BUILD_TYPE=Release \ --DANDROID_ABI="arm64-v8a" \ --DANDROID_PLATFORM=android-23 $1 - -make -j4 diff --git a/tools/mtmd/clip-quantize-cli.cpp b/tools/mtmd/clip-quantize-cli.cpp deleted file mode 100644 index 566506954..000000000 --- a/tools/mtmd/clip-quantize-cli.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "arg.h" -#include "base64.hpp" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -static void print_usage(int argc, char ** argv) { - (void) argc; - - fprintf(stderr, "usage: %s /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 6 - q5_0\n"); - fprintf(stderr, " type = 7 - q5_1\n"); - fprintf(stderr, " type = 8 - q8_0\n"); -} - -int main(int argc, char ** argv) { - if (argc != 4) { - print_usage(argc, argv); - return 1; - } - - const std::string fname_inp = argv[1]; - const std::string fname_out = argv[2]; - - const int itype = atoi(argv[3]); - - const int64_t t_main_start_us = ggml_time_us(); - - int64_t t_quantize_us = 0; - - // load the model - { - const int64_t t_start_us = ggml_time_us(); - - if (!clip_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) { - fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); - return 1; - } - - t_quantize_us = ggml_time_us() - t_start_us; - } - - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - printf("\n"); - printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f); - printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); - } - - return 0; -} diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 41ba45a79..a0f42e8c4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -3586,141 +3586,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return true; } -bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { - assert(itype < GGML_TYPE_COUNT); - ggml_type type = static_cast(itype); - - auto * ctx_clip = clip_init(fname_inp, clip_context_params{ - /* use_gpu */ false, - /* verbosity */ GGML_LOG_LEVEL_ERROR, - }); - - const auto & ctx_src = ctx_clip->ctx_gguf.get(); - const auto & ctx_data = ctx_clip->ctx_data.get(); - - auto * ctx_out = gguf_init_empty(); - gguf_set_kv(ctx_out, ctx_src); - gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out, "general.file_type", itype); - - auto fout = std::ofstream(fname_out, std::ios::binary); - - const int n_tensors = gguf_get_n_tensors(ctx_src); - - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx_src, i); - ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - gguf_add_tensor(ctx_out, cur); - } - - const size_t meta_size = gguf_get_meta_size(ctx_out); - for (size_t i = 0; i < meta_size; ++i) { - fout.put(0); - } - - // regexes of tensor names to be quantized - const std::vector k_names = { - ".*weight", - }; - - std::vector work(512); - std::vector conv_buf(512); - size_t total_size_org = 0; - size_t total_size_new = 0; - - for (int i = 0; i < n_tensors; ++i) { - const std::string name = gguf_get_tensor_name(ctx_src, i); - ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str()); - - enum ggml_type new_type; - void * new_data; - size_t new_size; - - bool quantize = false; - for (const auto & s : k_names) { - if (std::regex_match(name, std::regex(s))) { - quantize = true; - break; - } - } - - // quantize only 2D tensors and bigger than block size - quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); - - if (quantize) { - new_type = type; - if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) { - new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type - // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); - } - const size_t n_elms = ggml_nelements(cur); - float * f32_data; - - switch (cur->type) { - case GGML_TYPE_F32: - f32_data = (float *)cur->data; - break; - case GGML_TYPE_F16: - if (conv_buf.size() < n_elms) { - conv_buf.resize(n_elms); - } - for (size_t j = 0; j < n_elms; ++j) { - conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]); - } - f32_data = (float *)conv_buf.data(); - break; - default: - LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__); - gguf_free(ctx_out); - return false; - } - - if (work.size() < n_elms * 4) { - work.resize(n_elms * 4); - } - new_data = work.data(); - - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr); - } else { - new_type = cur->type; - new_data = cur->data; - new_size = ggml_nbytes(cur); - } - const size_t orig_size = ggml_nbytes(cur); - total_size_org += orig_size; - total_size_new += new_size; - gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data); - fout.write((const char *)new_data, new_size); - size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; - for (size_t j = 0; j < pad; ++j) { - fout.put(0); - } - - LOG_INF("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize, - orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); - } - - // go back to beginning of file and write the updated metadata - fout.seekp(0, std::ios::beg); - std::vector meta(meta_size); - gguf_get_meta_data(ctx_out, meta.data()); - fout.write((const char *)meta.data(), meta_size); - - fout.close(); - - clip_free(ctx_clip); - gguf_free(ctx_out); - - { - LOG_INF("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); - LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); - } - - return true; -} - int clip_n_mmproj_embd(const struct clip_ctx * ctx) { switch (ctx->proj_type) { case PROJECTOR_TYPE_LDP: diff --git a/tools/mtmd/convert_image_encoder_to_gguf.py b/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py similarity index 100% rename from tools/mtmd/convert_image_encoder_to_gguf.py rename to tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py diff --git a/tools/mtmd/glmedge-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py similarity index 100% rename from tools/mtmd/glmedge-convert-image-encoder-to-gguf.py rename to tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py diff --git a/tools/mtmd/glmedge-surgery.py b/tools/mtmd/legacy-models/glmedge-surgery.py similarity index 100% rename from tools/mtmd/glmedge-surgery.py rename to tools/mtmd/legacy-models/glmedge-surgery.py diff --git a/tools/mtmd/llava_surgery.py b/tools/mtmd/legacy-models/llava_surgery.py similarity index 100% rename from tools/mtmd/llava_surgery.py rename to tools/mtmd/legacy-models/llava_surgery.py diff --git a/tools/mtmd/llava_surgery_v2.py b/tools/mtmd/legacy-models/llava_surgery_v2.py similarity index 100% rename from tools/mtmd/llava_surgery_v2.py rename to tools/mtmd/legacy-models/llava_surgery_v2.py diff --git a/tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py similarity index 100% rename from tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py rename to tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py diff --git a/tools/mtmd/minicpmv-surgery.py b/tools/mtmd/legacy-models/minicpmv-surgery.py similarity index 100% rename from tools/mtmd/minicpmv-surgery.py rename to tools/mtmd/legacy-models/minicpmv-surgery.py diff --git a/tools/mtmd/llava.cpp b/tools/mtmd/llava.cpp deleted file mode 100644 index ebef8b3c1..000000000 --- a/tools/mtmd/llava.cpp +++ /dev/null @@ -1,591 +0,0 @@ -#include "clip.h" -#include "llava.h" - -#include "llama.h" -#include "ggml-cpp.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(LLAVA_LOG_OFF) -# define LOG_INF(...) -# define LOG_WRN(...) -# define LOG_ERR(...) -# define LOG_DBG(...) -#else // defined(LLAVA_LOG_OFF) -# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#endif // defined(LLAVA_LOG_OFF) - -// RGB uint8 image -struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... -struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - -struct clip_image_grid_shape { - int first; - int second; -}; - -// convenience cpp wrapper -struct clip_image_f32_batch_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_f32_batch_ptr; - -struct clip_image_size_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_size_ptr; - -/** - * Selects the best resolution from a list of possible resolutions based on the original size. - * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). - */ -static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; - } - } - - return best_fit; -} - -/** - * @brief Get the anyres image grid shape object - * - * @param image_size - * @param grid_pinpoints - * @param image_patch_size - * @return - */ -static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair & image_size, const std::vector> & grid_pinpoints, int image_patch_size) { - /** - Conversion from gguf flat array to vector: - std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - */ - auto best_resolution = select_best_resolution(image_size, grid_pinpoints); - return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; -} - -// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) -static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) { - struct { - struct ggml_context * ctx; - } model; - - const int32_t image_size = clip_get_image_size(ctx_clip); - const int32_t patch_size = clip_get_patch_size(ctx_clip); - - int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) - - int num_patches_width = grid_shape.first; // grid 1-4 - int num_patches_height = grid_shape.second; // grid 1-4 - - const size_t num_images = num_patches_width * num_patches_height + 1; - - // TODO: size calculation is not calculated - it's only tens of MB - size_t ctx_size = 0; - - { - ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features - ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); - } - - struct ggml_init_params params { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API - }; - - // Python reference code for full unpad: - /* - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1) - ), dim=-1) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - */ - // We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval. - // In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet. - // Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them. - // Once all images are processed to prepended the base_image_features without any changes. - - // Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling)) - /* - image_feature = image_feature.view(2, 2, 24, 24, 4096) - image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.view(2, 24, 2, 24, 4096) - image_feature = image_feature.flatten(0, 3) - - // Reshape to 4D tensor by merging the last two dimensions - image_feature = image_feature.view(2, 2, 24, 24*4096) - image_feature = image_feature.permute(0, 2, 1, 3).contiguous() - image_feature = image_feature.view(-1, 4096) - */ - - model.ctx = ggml_init(params); - - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4 - // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); - // fill it with the image embeddings, ignoring the base - for (size_t i = 1; i < num_images; i++) { - size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); - memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); - } - - struct ggml_cgraph * gf = ggml_new_graph(model.ctx); - size_t size_ele = ggml_type_size(GGML_TYPE_F32); - - struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features, - num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - num_patches_per_side, - num_patches_width, - num_patches_height, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0); - // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false); - struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); - /** - At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) - ), dim=-1) - * - */ - - // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false); - struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); - // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); - ggml_build_forward_expand(gf, flatten); - - ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; - GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend"); - ggml_backend_graph_compute(backend.get(), gf); - - struct ggml_tensor* result = ggml_graph_node(gf, -1); - - memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context - // append without newline tokens (default behavior in llava_arch when not using unpad ): - memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches - *n_img_pos_out = static_cast(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input)); - - // Debug: Test single segments - // Current findings: sending base image, sending a segment embedding all works similar to python - // However, permuted embeddings do not work yet (stride issue?) - // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context - // memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context - // *n_img_pos_out=576; - - ggml_free(model.ctx); - return true; -} - -static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { - int width = image->nx; - int height = image->ny; - int num_patches = (height / patch_size) * (width / patch_size); - clip_image_f32 * patch = clip_image_f32_init(); - patch->nx = patch_size * num_patches; - patch->ny = patch_size; - patch->buf.resize(3 * patch->nx * patch->ny); - - int patch_index = 0; - - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - for (int pi = 0; pi < patch_size; ++pi) { - for (int pj = 0; pj < patch_size; ++pj) { - int input_index = ((i + pi) * width + (j + pj)) * 3; - int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3; - patch->buf[output_index] = image->buf[input_index]; - patch->buf[output_index+1] = image->buf[input_index+1]; - patch->buf[output_index+2] = image->buf[input_index+2]; - } - } - patch_index++; - } - } - return patch; -} - -static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { - // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); - if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { - LOG_ERR("%s: unable to preprocess image\n", __func__); - return false; - } - - const int64_t t_img_enc_start_us = ggml_time_us(); - - const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - - const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); - - if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - clip_image_size load_image_size; - - for (size_t i = 0; i < n_imgs; i++) { - const int64_t t_img_enc_step_start_us = ggml_time_us(); - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - int patch_size = 14; - load_image_size.width = nx; - load_image_size.height = ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - - bool encoded = false; - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); - } - else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); - } - - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - int n_img_pos_out = 0; - for (size_t i = 0; i < image_embd_v.size(); i++) { - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - std::memcpy( - image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), - image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res); - } - *n_img_pos = n_img_pos_out; - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - load_image_size.width = img->nx; - load_image_size.height = img->ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); - } - else if (clip_is_glm(ctx_clip)){ - struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); - load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); - clip_add_load_image_size(ctx_clip, load_image_size); - - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); - int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); - *n_img_pos = (pos * pos + 2); - if (!encoded){ - LOG_ERR("Unable to encode image \n"); - return false; - } - } - else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { - // flat / default llava-1.5 type embedding - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - *n_img_pos = clip_n_output_tokens(ctx_clip, img_res); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 - if (!encoded) { - LOG_ERR("Unable to encode image\n"); - - return false; - } - } - else { - // spatial_unpad llava-1.6 type embedding - // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - for (size_t i = 0; i < n_imgs; i++) { - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - const int32_t * image_grid = clip_image_grid(ctx_clip); - const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); - - std::vector> grid_pinpoints; - for (size_t i = 0; i < num_gridpoints; i += 2) { - grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); - } - - const int32_t image_size = clip_get_image_size(ctx_clip); - - struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); - - int n_img_pos_out; - clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0); - clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input); - *n_img_pos = n_img_pos_out; - - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - - // debug image/segment/normalization content: - // clip_image_u8 * tmp = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*image_feature, *tmp); - // clip_image_save_to_bmp(*tmp, "image_feature.bmp"); - } - - LOG_INF("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); - - const int64_t t_img_enc_end_us = ggml_time_us(); - float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - - LOG_INF("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); - - return true; -} - -bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { - // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - auto n_image_embd = clip_n_mmproj_embd(ctx_clip); - if (n_image_embd != n_llama_embd) { - LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); - return false; - } - return true; -} - -bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - // Granite vision uses up to 10 patches + base patch - int num_max_patches = 11; - if (clip_is_minicpmv(ctx_clip)) { - num_max_patches = 10; - } - if (clip_is_glm(ctx_clip)) { - num_max_patches = 1; - } - float * image_embd; - if (clip_is_qwen2vl(ctx_clip)) { - // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. - image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny)); - } else { - image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model - } - if (!image_embd) { - LOG_ERR("Unable to allocate memory for image embeddings\n"); - return false; - } - - int n_img_pos; - if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) { - LOG_ERR("%s: cannot encode image, aborting\n", __func__); - free(image_embd); - return false; - } - *image_embd_out = image_embd; - *n_img_pos_out = n_img_pos; - - return true; -} - -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - -bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - - for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { - int n_eval = image_embed->n_image_pos - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { - LOG_ERR("%s : failed to eval\n", __func__); - return false; - } - *n_past += n_eval; - } - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) { - clip_image_u8 * img = clip_image_u8_init(); - if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { - clip_image_u8_free(img); - LOG_ERR("%s: can't load image from bytes, is it a valid image?", __func__); - return NULL; - } - - float* image_embed = NULL; - int n_image_pos = 0; - bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); - if (!image_embed_result) { - clip_image_u8_free(img); - LOG_ERR("%s: couldn't embed the image\n", __func__); - return NULL; - } - - clip_image_u8_free(img); - auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); - result->embed = image_embed; - result->n_image_pos = n_image_pos; - return result; -} - -static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) { - auto file = fopen(path, "rb"); - if (file == NULL) { - LOG_ERR("%s: can't read file %s\n", __func__, path); - return false; - } - - fseek(file, 0, SEEK_END); - auto fileSize = ftell(file); - fseek(file, 0, SEEK_SET); - - auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data - if (buffer == NULL) { - LOG_ERR("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); - perror("Memory allocation error"); - fclose(file); - return false; - } - errno = 0; - size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer - if (ferror(file)) { - LOG_ERR("read error: %s", strerror(errno)); - free(buffer); - fclose(file); - return false; - } - if (ret != (size_t) fileSize) { - LOG_ERR("unexpectedly reached end of file"); - free(buffer); - fclose(file); - return false; - } - fclose(file); // Close the file - - *bytesOut = buffer; - *sizeOut = fileSize; - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) { - unsigned char* image_bytes; - long image_bytes_length; - auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); - if (!loaded) { - LOG_ERR("%s: failed to load %s\n", __func__, image_path); - return NULL; - } - - llava_image_embed *embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length); - free(image_bytes); - - return embed; -} - -void llava_image_embed_free(struct llava_image_embed * embed) { - free(embed->embed); - free(embed); -} diff --git a/tools/mtmd/llava.h b/tools/mtmd/llava.h deleted file mode 100644 index b6feb3027..000000000 --- a/tools/mtmd/llava.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef LLAVA_H -#define LLAVA_H - -#include "ggml.h" - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define LLAVA_API __declspec(dllexport) -# else -# define LLAVA_API __declspec(dllimport) -# endif -# else -# define LLAVA_API __attribute__ ((visibility ("default"))) -# endif -#else -# define LLAVA_API -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -struct clip_ctx; -struct llava_image_embed { - float * embed; - int n_image_pos; -}; - -/** sanity check for clip <-> llava embed size match */ -LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip); - -LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); - -/** build an image embed from image file bytes */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -/** build an image embed from a path to an image filename */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -/** free an embedding made with llava_image_embed_make_* */ -LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); - -/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); - -#ifdef __cplusplus -} -#endif - -#endif