From 3d26a09dc7b1a7c13da57fdd26d1cf22efa81229 Mon Sep 17 00:00:00 2001 From: R Date: Tue, 6 Jan 2026 16:17:13 +0100 Subject: [PATCH 01/19] server : add thinking content blocks to Anthropic Messages API (#18551) * server : add thinking content blocks to Anthropic Messages API Add support for returning reasoning/thinking content in Anthropic API responses when using models with --reasoning-format deepseek and the thinking parameter enabled. - Non-streaming: adds thinking block before text in content array - Streaming: emits thinking_delta events with correct block indices - Partial streaming: tracks reasoning state across chunks via anthropic_has_reasoning member variable Tested with bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF model. * server : fix Anthropic API streaming for thinking content blocks Add signature field and fix duplicate content_block_start events in Anthropic Messages API streaming responses for reasoning models. * server: refactor Anthropic streaming state to avoid raw pointer Replace raw pointer to task_result_state with direct field copies: - Copy state fields in update() before processing chunk - Use local copies in to_json_anthropic() instead of dereferencing - Pre-compute state updates for next chunk in update() This makes the data flow clearer and avoids unsafe pointer patterns. --- tools/server/server-task.cpp | 149 +++++++++++++++--- tools/server/server-task.h | 26 +++ .../tests/unit/test_compat_anthropic.py | 89 +++++++++++ 3 files changed, 246 insertions(+), 18 deletions(-) diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 6d374131e..ed4f6546e 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -814,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() { msg.content = content; } + // thinking block comes first (Anthropic extended thinking format) + if (!msg.reasoning_content.empty()) { + content_blocks.push_back({ + {"type", "thinking"}, + {"thinking", msg.reasoning_content}, + {"signature", ""} // empty signature for local models (no cryptographic verification) + }); + } + if (!msg.content.empty()) { content_blocks.push_back({ {"type", "text"}, @@ -862,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; } - bool has_text = !oaicompat_msg.content.empty(); + bool has_thinking = !oaicompat_msg.reasoning_content.empty(); + bool has_text = !oaicompat_msg.content.empty(); size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - bool text_block_started = false; + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + size_t text_block_index = has_thinking ? 1 : 0; + + bool thinking_block_started = false; + bool text_block_started = false; std::unordered_set tool_calls_started; for (const auto & diff : oaicompat_msg_diffs) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", thinking_block_index}, + {"content_block", { + {"type", "thinking"}, + {"thinking", ""} + }} + }} + }); + thinking_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content if (!diff.content_delta.empty()) { if (!text_block_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", text_block_index}, {"content_block", { {"type", "text"}, {"text", ""} @@ -889,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -898,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index; if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; @@ -935,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { } } + // close content blocks in order + if (has_thinking) { + // Anthropic API requires a signature_delta before closing thinking blocks + // We use an empty signature since we can't generate a cryptographic signature for local models + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "signature_delta"}, + {"signature", ""} + }} + }} + }); + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", thinking_block_index} + }} + }); + } + if (has_text) { events.push_back({ {"event", "content_block_stop"}, {"data", { {"type", "content_block_stop"}, - {"index", 0} + {"index", text_block_index} }} }); } for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i; events.push_back({ {"event", "content_block_stop"}, {"data", { @@ -1154,11 +1225,10 @@ json server_task_result_rerank::to_json() { json server_task_result_cmpl_partial::to_json_anthropic() { json events = json::array(); bool first = (n_decoded == 1); - bool text_block_started = false; + // use member variables to track block state across streaming calls + // (anthropic_thinking_block_started, anthropic_text_block_started) if (first) { - text_block_started = false; - events.push_back({ {"event", "message_start"}, {"data", { @@ -1180,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + // use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated + size_t text_block_index = anthropic_has_reasoning ? 1 : 0; + + // use local copies of streaming state (copied from task_result_state in update()) + // these reflect the state BEFORE this chunk was processed + bool thinking_started = anthropic_thinking_block_started; + bool text_started = anthropic_text_block_started; + for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", thinking_block_index}, {"content_block", { - {"type", "text"}, - {"text", ""} + {"type", "thinking"}, + {"thinking", ""} }} }} }); - text_block_started = true; + thinking_started = true; } events.push_back({ {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content + if (!diff.content_delta.empty()) { + if (!text_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", text_block_index}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -1210,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + // use anthropic_has_reasoning for thinking block count (persists across calls) + size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 1 : 0) + diff.tool_call_index; if (!diff.tool_call_delta.name.empty()) { events.push_back({ diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 687770de5..ead149118 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -96,6 +96,10 @@ struct task_result_state { std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + // for Anthropic API streaming: track content block state across chunks + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) : oaicompat_chat_syntax(oaicompat_chat_syntax) {} @@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; + // for Anthropic API: track if any reasoning content has been generated + bool anthropic_has_reasoning = false; + // Streaming state copied from task_result_state for this chunk + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + virtual bool is_stop() override { return false; // in stream mode, partial responses are not considered stop } @@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result { virtual void update(task_result_state & state) override { is_updated = true; state.update_chat_msg(content, true, oaicompat_msg_diffs); + // track if the accumulated message has any reasoning content + anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty(); + + // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk) + anthropic_thinking_block_started = state.anthropic_thinking_block_started; + anthropic_text_block_started = state.anthropic_text_block_started; + + // Pre-compute state updates based on diffs (for next chunk) + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) { + state.anthropic_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !state.anthropic_text_block_started) { + state.anthropic_text_block_started = true; + } + } } json to_json_non_oaicompat(); diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index e0a003557..e16e0235c 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format(): assert "input_tokens" in anthropic_res.body["usage"] assert "completion_tokens" in openai_res.body["usage"] assert "output_tokens" in anthropic_res.body["usage"] + + +# Extended thinking tests with reasoning models + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [False, True]) +def test_anthropic_thinking_with_reasoning_model(stream): + """Test that thinking content blocks are properly returned for reasoning models""" + global server + server = ServerProcess() + server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF" + server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf" + server.reasoning_format = "deepseek" + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 1024 + server.server_port = 8084 + server.start(timeout_seconds=600) # large model needs time to download + + if stream: + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ], + "stream": True + }) + + events = list(res) + + # should have thinking content block events + thinking_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "thinking"] + assert len(thinking_starts) > 0, "Should have thinking content_block_start event" + assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0" + + # should have thinking_delta events + thinking_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "thinking_delta"] + assert len(thinking_deltas) > 0, "Should have thinking_delta events" + + # should have signature_delta event before thinking block closes (Anthropic API requirement) + signature_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "signature_delta"] + assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block" + + # should have text block after thinking + text_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "text"] + assert len(text_starts) > 0, "Should have text content_block_start event" + assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)" + else: + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + content = res.body["content"] + assert len(content) >= 2, "Should have at least thinking and text blocks" + + # first block should be thinking + thinking_blocks = [b for b in content if b.get("type") == "thinking"] + assert len(thinking_blocks) > 0, "Should have thinking content block" + assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field" + assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty" + assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)" + + # should also have text block + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) > 0, "Should have text content block" From 968929528c6a05e10249366fbe5f0330ad9af678 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 6 Jan 2026 07:26:07 -0800 Subject: [PATCH 02/19] mmq.cu: tune mmq/rocblas switching for RDNA (#18537) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Patch perf regression for mmq kernels in ROCm recover performance regression for https://github.com/ggml-org/llama.cpp/issues/17917 * add n_experts branch like the cdna path * mmq.cu: tune mmq/wmma switching for RDNA * mmq.cu: move amd wmma mmq/wmma switching behind IS_RDNA3 * Update ggml/src/ggml-cuda/mmq.cu Co-authored-by: Johannes Gäßler --------- Co-authored-by: Jiacheng (Jason) Chen <76919340+jiachengjason@users.noreply.github.com> Co-authored-by: jiachengjason Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cu | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 85692d454..ceb95758d 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } if (amd_wmma_available(cc)) { + // RDNA 4 is consistently worse on rocblas + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts almost always better on MMQ + // due to a large amount of graph splits + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + switch (type) { + // These quants are really bad on MMQ + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q6_K: + // These quants are usually worse but not always + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return ne11 <= 128; + default: + return true; + } + } return true; } From 090b137e56a80b189dbced7d31e637951f3e123f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 6 Jan 2026 23:48:45 +0800 Subject: [PATCH 03/19] ggml-cuda: refactor cuda graph usage (#18637) * ggml-cuda: refactor cuda graph usage * use is_enabled() instead of enabled --- ggml/src/ggml-cuda/common.cuh | 24 ++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 138 ++++++++++++-------------------- ggml/src/ggml-cuda/mean.cu | 6 +- 3 files changed, 72 insertions(+), 96 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 995b774c2..9516d8ec8 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { +struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; @@ -1061,11 +1061,25 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; - bool cuda_graphs_enabled = false; - std::vector ggml_graph_properties; - std::vector extraneous_srcs_properties; + std::vector props; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } #endif }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75269170c..bac69cdd1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; @@ -2915,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, return use_cuda_graph; } -static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node->data; - graph_node_properties->node_op = node->op; +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + props->node_address = node->data; + props->node_op = node->op; for (int i = 0; i < GGML_MAX_DIMS; i++) { - graph_node_properties->ne[i] = node->ne[i]; - graph_node_properties->nb[i] = node->nb[i]; + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } - memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_address && node->op != GGML_OP_VIEW) { return false; } - if (node->op != graph_node_properties->node_op) { + if (node->op != props->node_op) { return false; } for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { + if (node->ne[i] != props->ne[i]) { return false; } - if (node->nb[i] != graph_node_properties->nb[i]) { + if (node->nb[i] != props->nb[i]) { return false; } } for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + node->src[i]->data != props->src_address[i] && node->op != GGML_OP_VIEW ) { return false; @@ -2957,56 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool cuda_graph_update_required = false; + bool res = false; if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; + res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs); + if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + res = true; + cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); } for (int i = 0; i < cgraph->n_leafs; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + bool props_match= true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); } - return cuda_graph_update_required; + return res; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; @@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { + bool graph_evaluated_or_captured = false; + // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; @@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx); } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); @@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } -static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) { +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { #ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - // Objects required for CUDA Graph if (cuda_ctx->cuda_graph == nullptr) { cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } - bool use_cuda_graph = true; - if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif } } - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph; + return cuda_ctx->cuda_graph->is_enabled(); #else - bool use_cuda_graph = false; + return false; #endif // USE_CUDA_GRAPH - - return use_cuda_graph; } static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; - // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called) - // we call it here instead. #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + if (cuda_ctx->cuda_graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); - - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } - - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; - cuda_ctx->cuda_graph->cuda_graphs_enabled = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } + cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); } #endif // USE_CUDA_GRAPH @@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - bool graph_evaluated_or_captured = false; - - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 691d8dcb1..60542fc19 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // CUDA_GRAPHS_DISABLED ((ncols > 65536) && ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + ctx.cuda_graph->is_enabled())) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + ctx.cuda_graph->is_enabled()))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH From ea13cba85092fa17cae4d7bd064e5476a86ea53c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 6 Jan 2026 10:37:07 -0600 Subject: [PATCH 04/19] vulkan: support buffer_from_host_ptr (#18467) * vulkan: support buffer_from_host_ptr * hacky use of buffer_from_host_ptr for directio * disable buffer_from_host_ptr cap * use external memory for ggml_vk_host_malloc, revert model loader changes * disable external_memory_host for MoltenVK * take buffer memory types into account * don't use external_memory_host for ggml_vk_host_malloc --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 175 +++++++++++++++++++++++---- 1 file changed, 149 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 502a4deeb..3c13777b8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -550,6 +550,8 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t max_buffer_size; uint64_t suballocation_block_size; + uint64_t min_imported_host_pointer_alignment; + bool external_memory_host {}; bool fp16; bool bf16; bool pipeline_robustness; @@ -2410,7 +2412,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe return indices; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list, + void *import_ptr = nullptr) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_buffer_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); @@ -2439,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std nullptr, }; + vk::ExternalMemoryBufferCreateInfo external_memory_bci; + if (import_ptr) { + external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + buffer_create_info.setPNext(&external_memory_bci); + } + buf->buffer = device->device.createBuffer(buffer_create_info); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); @@ -2453,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std mem_flags_info.setPNext(&mem_priority_info); } - for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { - const auto & req_flags = *it; - - const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); - - if (memory_type_indices.empty()) { - continue; + if (import_ptr) { + vk::MemoryHostPointerPropertiesEXT host_pointer_props; + try { + host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what()); + device->device.destroyBuffer(buf->buffer); + return {}; } - buf->memory_property_flags = req_flags; + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - bool done = false; + uint32_t memory_type_idx; + vk::MemoryPropertyFlags property_flags = *req_flags_list.begin(); + for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) { + if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } + if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } - for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); - done = true; + vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx]; + // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed + if ((memory_type.propertyFlags & property_flags) == property_flags) { + property_flags = memory_type.propertyFlags; break; - } catch (const vk::SystemError& e) { - // loop and retry - // during last attempt throw the exception - if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { - device->device.destroyBuffer(buf->buffer); - throw e; - } } } + if (memory_type_idx == 32) { + GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n"); + device->device.destroyBuffer(buf->buffer); + return {}; + } - if (done) { - break; + buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags; + try { + vk::ImportMemoryHostPointerInfoEXT import_info; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + import_info.pHostPointer = import_ptr; + import_info.setPNext(&mem_flags_info); + buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info }); + } catch (const vk::SystemError& e) { + } + } else { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_indices.empty()) { + continue; + } + buf->memory_property_flags = req_flags; + + bool done = false; + + for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); + done = true; + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (done) { + break; + } } } @@ -2492,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->ptr = nullptr; - if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + if (import_ptr) { + buf->ptr = import_ptr; + } else { + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } } device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); @@ -4447,6 +4505,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { + device->external_memory_host = true; } } @@ -4461,6 +4521,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -4500,11 +4561,22 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; } + if (device->external_memory_host) { + last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props; + last_struct = (VkBaseOutStructure *)&external_memory_host_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + if (device->driver_id == vk::DriverId::eMoltenvk) { + // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622 + // is available in the Vulkan SDK. + device->external_memory_host = false; + } + // Implementing the async backend interfaces seems broken on older Intel HW, // see https://github.com/ggml-org/llama.cpp/issues/17302. device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || @@ -4586,6 +4658,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); @@ -4717,6 +4791,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_pipeline_executable_properties"); } + if (device->external_memory_host) { + device_extensions.push_back("VK_EXT_external_memory_host"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -14773,6 +14851,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); } +static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { + if (!device->external_memory_host) { + return {}; + } + + uintptr_t uptr = reinterpret_cast(ptr); + if (uptr & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + if (size & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + + const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached; + + vk_buffer buf {}; + try { + buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what()); + } + + return buf; +} + +static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")"); + GGML_UNUSED(max_tensor_size); + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size); + + if (!buf) { + return {}; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name); + + ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size); + + return ret; +} + static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .get_name = */ ggml_backend_vk_device_get_name, /* .get_description = */ ggml_backend_vk_device_get_description, @@ -14782,7 +14905,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .init_backend = */ ggml_backend_vk_device_init, /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr, /* .supports_op = */ ggml_backend_vk_device_supports_op, /* .supports_buft = */ ggml_backend_vk_device_supports_buft, /* .offload_op = */ ggml_backend_vk_device_offload_op, From 07fbe19f1fbcfa09abca7cccc62eaf82c1567b7e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 6 Jan 2026 17:51:08 +0100 Subject: [PATCH 05/19] arg: use CSV escape style for multiple-value args (#18643) * arg: use CSV escape style for multiple-value args * add test --- common/arg.cpp | 107 ++++++++++++++++++++++++-------------- tests/test-arg-parser.cpp | 9 ++++ 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b52b3e70b..c3610d262 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } +// Simple CSV parser that handles quoted fields and escaped quotes +// example: +// input: value1,"value, with, commas","value with ""escaped"" quotes",value4 +// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4] +static std::vector parse_csv_row(const std::string& input) { + std::vector fields; + std::string field; + bool in_quotes = false; + + for (size_t i = 0; i < input.length(); ++i) { + char ch = input[i]; + + if (ch == '"') { + if (!in_quotes) { + // start of quoted field (only valid if at beginning of field) + if (!field.empty()) { + // quote appeared in middle of unquoted field, treat as literal + field += '"'; + } else { + in_quotes = true; // start + } + } else { + if (i + 1 < input.length() && input[i + 1] == '"') { + // escaped quote: "" + field += '"'; + ++i; // skip the next quote + } else { + in_quotes = false; // end + } + } + } else if (ch == ',') { + if (in_quotes) { + field += ','; + } else { + fields.push_back(std::move(field)); + field.clear(); + } + } else { + field += ch; + } + } + + // Add the last field + fields.push_back(std::move(field)); + + return fields; +} + common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { // per-example default params // we define here to make sure it's included in llama-gen-docs @@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--in-file"}, "FNAME", "an input file (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2002,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--image", "--audio"}, "FILE", "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.image.emplace_back(item); } } @@ -2259,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--override-kv"}, "KEY=TYPE:VALUE,...", - "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n" + "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n" "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false", [](common_params & params, const std::string & value) { - std::vector kv_overrides; - - std::string current; - bool escaping = false; - - for (const char c : value) { - if (escaping) { - current.push_back(c); - escaping = false; - } else if (c == '\\') { - escaping = true; - } else if (c == ',') { - kv_overrides.push_back(current); - current.clear(); - } else { - current.push_back(c); - } - } - - if (escaping) { - current.push_back('\\'); - } - - kv_overrides.push_back(current); - - for (const auto & kv_override : kv_overrides) { - if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) { - throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str())); + for (const auto & item : parse_csv_row(value)) { + if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str())); } } } @@ -2306,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (use comma-separated values to load multiple adapters)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.lora_adapters.push_back({ item, 1.0, "", "", nullptr }); } } @@ -2317,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n" "note: use comma-separated values", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("lora-scaled format: FNAME:SCALE"); @@ -2331,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--control-vector"}, "FNAME", "add a control vector\nnote: use comma-separated values to add multiple control vectors", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.control_vectors.push_back({ 1.0f, item, }); } } @@ -2341,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "add a control vector with user defined scaling SCALE\n" "note: use comma-separated values (format: FNAME:SCALE,...)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE"); @@ -2439,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--context-file"}, "FNAME", "file to load context from (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item, std::ios::binary); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2675,9 +2698,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(common_arg( {"--api-key"}, "KEY", - "API key to use for authentication (default: none)", + "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)", [](common_params & params, const std::string & value) { - params.api_keys.push_back(value); + for (const auto & key : parse_csv_row(value)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( @@ -2691,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::string key; while (std::getline(key_file, key)) { if (!key.empty()) { - params.api_keys.push_back(key); + params.api_keys.push_back(key); } } key_file.close(); @@ -2713,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); add_opt(common_arg( {"--chat-template-kwargs"}, "STRING", - string_format("sets additional params for the json template parser"), + "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'", [](common_params & params, const std::string & value) { auto parsed = json::parse(value); for (const auto & item : parsed.items()) { diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 1bbb745e7..e995974a2 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -127,6 +127,15 @@ int main(void) { assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); assert(params.speculative.n_max == 123); + // multi-value args (CSV) + argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"}; + assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); + assert(params.lora_adapters.size() == 4); + assert(params.lora_adapters[0].path == "file1.gguf"); + assert(params.lora_adapters[1].path == "file2,2.gguf"); + assert(params.lora_adapters[2].path == "file3\"3\".gguf"); + assert(params.lora_adapters[3].path == "file4\".gguf"); + // skip this part on windows, because setenv is not supported #ifdef _WIN32 printf("test-arg-parser: skip on windows build\n"); From 24af22fc365ea6ef8e37875108a83658aa16fc8a Mon Sep 17 00:00:00 2001 From: Aadeshveer Singh <24b0926@iitb.ac.in> Date: Tue, 6 Jan 2026 23:54:34 +0530 Subject: [PATCH 06/19] ggml : optimize cuda ssm_scan using warp-level reduction (#18505) * ggml : optimize cuda ssm_scan using warp-level reduction * ggml : apply code review suggestions (style, const, constexpr) * ggml : add TODO regarding stride consistency --- ggml/src/ggml-cuda/ssm-scan.cu | 133 ++++++++++++--------------------- 1 file changed, 49 insertions(+), 84 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 6b424381d..c1d4e2bc8 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1) #endif // __clang__ // assumes as many threads as d_state -template +template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1) const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { - const int head_idx = (blockIdx.x * splitH) / d_head; - const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); - const int seq_idx = blockIdx.y; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; - float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); // strides across n_seq_tokens const int stride_x = src1_nb2 / sizeof(float); @@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1) const int stride_C = src5_nb2 / sizeof(float); const int stride_y = n_head * d_head; - float state[splitH]; - // for the parallel accumulation - __shared__ float stateC[splitH * d_state]; + float state[c_factor]; + float state_sum = 0.0f; #pragma unroll - for (int j = 0; j < splitH; j++) { - state[j] = s0_block[j * d_state + threadIdx.x]; + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; } for (int64_t i = 0; i < n_tok; i++) { - // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements - // TODO: only calculate B and C once per head group - // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. - float dt_soft_plus = dt_block[i * stride_dt]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(expf(dt_soft_plus)); - } - const float dA = expf(dt_soft_plus * A_block[0]); - const float B = B_block[i * stride_B + threadIdx.x]; - const float C = C_block[i * stride_C + threadIdx.x]; + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); - // across d_head + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; #pragma unroll - for (int j = 0; j < splitH; j++) { - const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; - - state[j] = (state[j] * dA) + (B * x_dt); - - stateC[j * d_state + threadIdx.x] = state[j] * C; + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; } - __syncthreads(); + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); - // parallel accumulation for stateC - // TODO: simplify - { - static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); - static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); - - // reduce until w matches the warp size - // TODO: does this work even when the physical warp size is 64? -#pragma unroll - for (int w = d_state; w > WARP_SIZE; w >>= 1) { - // (assuming there are d_state threads) -#pragma unroll - for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { - // TODO: check for bank conflicts - const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); - stateC[k] += stateC[k + (w >> 1)]; - - } - __syncthreads(); - } - - static_assert(splitH >= d_state / WARP_SIZE); - -#pragma unroll - for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { - float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; - y = warp_reduce_sum(y); - - // store the above accumulations - if (threadIdx.x % WARP_SIZE == 0) { - const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); - y_block[i * stride_y + k] = y; - } - } + if (lane == 0) { + y_warp[i * stride_y] = state_sum; } } // write back the state #pragma unroll - for (int j = 0; j < splitH; j++) { - s_block[j * d_state + threadIdx.x] = state[j]; + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; } } @@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - GGML_ASSERT(d_state % threads == 0); - // NOTE: can be any power of two between 4 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 128><<>>( + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else if (d_state == 256) { // Falcon-H1 - const int threads = 256; - // NOTE: can be any power of two between 8 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 256><<>>( + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa } } else { // Mamba-1 + constexpr int threads = 128; GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); From 68b4d516c305325d31e698c4673b691d2a9d879f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 6 Jan 2026 20:02:30 +0100 Subject: [PATCH 07/19] llama-params-fit: fix last devices with low VRAM (#18494) --- src/llama.cpp | 66 +++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 98fb77084..0162ae8d5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -359,6 +359,11 @@ static void llama_params_fit_impl( // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } }; const size_t ntbo = llama_max_tensor_buft_overrides(); @@ -382,7 +387,7 @@ static void llama_params_fit_impl( size_t itbo = 0; for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part; + il0 += ngl_per_device[id].n_full(); for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { if (itbo + 1 >= ntbo) { tensor_buft_overrides[itbo].pattern = nullptr; @@ -393,7 +398,7 @@ static void llama_params_fit_impl( + std::to_string(ntbo) + " is insufficient for model"); } tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); itbo++; } il0 += ngl_per_device[id].n_part; @@ -468,20 +473,14 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd - 1; ++id) { - overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); } - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - if (hp_nex > 0) { - for (size_t id = 0; id < nd; id++) { - ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; - } - } // optimize the number of layers per device using the method of false position: // - ngl_per_device has 0 layers for each device, lower bound @@ -512,9 +511,6 @@ static void llama_params_fit_impl( if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - if (hp_nex > 0 && size_t(id) == nd - 1) { - delta--; - } LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); @@ -524,7 +520,8 @@ static void llama_params_fit_impl( std::vector ngl_per_device_test = ngl_per_device; ngl_per_device_test[id].n_layer += step_size; if (hp_nex) { - ngl_per_device_test[id].n_part += step_size; + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? + step_size - 1 : step_size; // the first layer is the output layer which must always be full } const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); @@ -573,7 +570,7 @@ static void llama_params_fit_impl( assert(id_dense_start < nd); LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start; id++) { + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; @@ -585,12 +582,8 @@ static void llama_params_fit_impl( std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); - assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); - assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); - uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -606,7 +599,7 @@ static void llama_params_fit_impl( ngl_per_device_test[id].n_layer += n_convert_jd; n_converted_test += n_convert_jd; - if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { break; } } @@ -625,8 +618,8 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); } - delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); } } else { ngl_per_device = ngl_per_device_high; @@ -644,14 +637,19 @@ static void llama_params_fit_impl( ngl_per_device_test[id_dense_start_test].n_part--; ngl_per_device_test[id].n_layer++; ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { id_dense_start_test++; } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", @@ -659,9 +657,10 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", @@ -670,9 +669,10 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", @@ -687,6 +687,14 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } From ccbc84a5374bab7a01f68b129411772ddd8e7c79 Mon Sep 17 00:00:00 2001 From: Tarek Dakhran Date: Tue, 6 Jan 2026 21:00:29 +0100 Subject: [PATCH 08/19] mtmd: mtmd_audio_streaming_istft (#18645) Change is decoupled from https://github.com/ggml-org/llama.cpp/pull/18641. [LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) needs streaming istft for generating output audio. * add streaming ISTFT class (`mtmd_audio_streaming_istft`) with overlap-add for audio reconstruction * replace global audio cache with per-instance cache, the model requires two independent caches, for preprocessing (audio input) and for istft (audio output). * unified templated FFT/IFFT implementation supporting both forward and inverse transforms --- tools/mtmd/mtmd-audio.cpp | 570 ++++++++++++++++++++++++-------------- tools/mtmd/mtmd-audio.h | 73 +++++ 2 files changed, 428 insertions(+), 215 deletions(-) diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e99101184..e8eef035f 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } +} - std::vector data; -}; +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; } - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); } - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. - // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; + const int n_fft_bins = n_fft / 2 + 1; - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } - - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + const float scale = Inverse ? (1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. @@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // void mtmd_audio_preprocessor_conformer::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_conformer::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { // empty audio if (n_samples == 0) { return false; @@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess( params.use_natural_log = true; params.norm_per_feature = true; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess( output.push_back(std::move(out_full)); return true; } + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index d484c9d03..016c7392e 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; }; struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; }; From 95ea9e0861b28adca740dbc09494f72105c9b92b Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Tue, 6 Jan 2026 17:38:29 -0800 Subject: [PATCH 09/19] Hexagon add support for f16/f32 flash attention, scale, set-rows and improve f16/32 matmul (#18611) * hexagon: improve fp16 matmul and add fp32/fp16 flash-attention * hexagon: add support for set-rows fp32 -> fp16 with i32/i64 row-idx * hexagon: add support for SCALE fp32 * hexagon: replace scalar fp32 -> fp16 copy with HVX * hexagon: optimize flash_atten_ext with aligned VTCM buffers and DMA - Implements double-buffered DMA prefetching for K, V, and Mask tensors. - Ensures K and V rows in VTCM are padded to 128 bytes to support aligned HVX operations. - Correctly synchronizes DMA transfers to prevent race conditions. - Uses `FLASH_ATTN_BLOCK_SIZE` of 128 for efficient chunking. * hexagon: use aligned mad_f16 * hexagon: flash_atten more aligned ops * hexagon: optimize scale_f32 hvx helpers * hexagon: unroll fa loops * hexagon: remove unused set-rows log * hexagon: flash_attn_ext add support for DMAing Q - Update `op_flash_attn_ext` to include Q row size in scratchpad allocation. - Pad Q row size to 128 bytes for alignment. - Implement DMA transfer for Q tensor in `flash_attn_ext_f16_thread`. - Update dot product computations to use VTCM-buffered Q data. * hexagon: fix handling of NANs hvx dotproducts * hexagon: cleanup spad allocation in flash-atten * hexagon: improve fp16/fp32 matmul - Introduced `vec_dot_f16_f16` and `vec_dot_f16_f16_rx2` kernels using efficient HVX dot product intrinsics. - Added `quantize_fp32_f16` to copy/convert weights from DDR to VTCM - Updated `op_matmul` to use the optimized path when VTCM capacity allows and broadcasting requirements are compatible. - Implemented fallback logic to the original implementation for complex broadcasting scenarios. * hexagon: fix HVX_ARCH check * hexagon: matmul cleanup and fp16 fixes Use aligned vec_dot_f16 for 2d matmuls and unaligned version for 4d. * hexagon: fix fp16 x fp16 matmuls and some minor refactoring * hexagon: add support for GET_ROWS f32 -> f32 Also optimize SET_ROWS threading a bit when we have just a few rows to process. * hexagon: optimize set-rows threading * hexagon: update adb/run-bench.sh to properly support experimental and verbose options * hexagon: flash_atten use aligned vectors for dot products --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 161 +++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 3 + ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 566 +++++++++++++ ggml/src/ggml-hexagon/htp/get-rows-ops.c | 112 +++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 5 - ggml/src/ggml-hexagon/htp/htp-msg.h | 10 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 28 + ggml/src/ggml-hexagon/htp/hvx-utils.c | 51 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 265 +++++- ggml/src/ggml-hexagon/htp/main.c | 162 +++- ggml/src/ggml-hexagon/htp/matmul-ops.c | 912 ++++++++++++--------- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 168 ++++ ggml/src/ggml-hexagon/htp/softmax-ops.c | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 34 +- scripts/snapdragon/adb/run-bench.sh | 14 +- 15 files changed, 2018 insertions(+), 477 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/flash-attn-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/get-rows-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/set-rows-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 13b96d61f..365a24b49 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_ return true; } +static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * src3 = op->src[3]; + const struct ggml_tensor * src4 = op->src[4]; + const struct ggml_tensor * dst = op; + + // Check for F16 support only as requested + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) { + return false; + } + + if (src3 && src3->type != GGML_TYPE_F16) { // mask + return false; + } + + if (src4 && src4->type != GGML_TYPE_F32) { // sinks + return false; + } + + // For now we support F32 or F16 output as htp backend often converts output on the fly if needed, + // but the op implementation writes to F16 or F32. + // Let's assume dst can be F32 or F16. + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + return opt_experimental; +} + static bool hex_supported_src0_type(ggml_type t) { return t == GGML_TYPE_F32; } @@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if (dst->type != GGML_TYPE_F32) { return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } @@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; // typically the lm-head which would be too large for VTCM } - // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false; if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { return false; } @@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session } break; - case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } - break; - default: return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return true; } +static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F16) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t d->offset = (uint8_t *) t->data - buf->base; d->size = ggml_nbytes(t); + if (!d->size) { + // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty + d->size = 64; + } + switch (type) { case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: // Flush CPU @@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_GET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); @@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SCALE: + req->op = HTP_OP_SCALE; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs return n_bufs; } +static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_FLASH_ATTN_EXT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: @@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_SET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_GET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); break; @@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_rope(sess, op); break; + case GGML_OP_FLASH_ATTN_EXT: + supp = ggml_hexagon_supported_flash_attn_ext(sess, op); + break; + + case GGML_OP_SET_ROWS: + supp = ggml_hexagon_supported_set_rows(sess, op); + break; + + case GGML_OP_GET_ROWS: + supp = ggml_hexagon_supported_get_rows(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2cf8aaa42..6a34a215f 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED softmax-ops.c act-ops.c rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c new file mode 100644 index 000000000..04a7b843c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -0,0 +1,566 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + + hvx_vec_store_u(r, 4, rsum); +} + +// Dot product of two F16 vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(r, 4, rsum); +} + +// MAD: y (F32) += x (F16) * v (float) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { + const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S = hvx_vec_splat_fp16(s); + + uint32_t i = 0; + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); + ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + } + + if (nloe) { + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + + HVX_Vector xs = Q6_V_lo_W(xs_p); + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + } + } +} + +#define FLASH_ATTN_BLOCK_SIZE 128 + +static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; + struct htp_tensor * dst = &octx->dst; + + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + const uint32_t nek0 = k->ne[0]; + const uint32_t nek1 = k->ne[1]; + const uint32_t nek2 = k->ne[2]; + const uint32_t nek3 = k->ne[3]; + + const uint32_t nev0 = v->ne[0]; + const uint32_t nev1 = v->ne[1]; + const uint32_t nev2 = v->ne[2]; + const uint32_t nev3 = v->ne[3]; + + const uint32_t nbq1 = q->nb[1]; + const uint32_t nbq2 = q->nb[2]; + const uint32_t nbq3 = q->nb[3]; + + const uint32_t nbk1 = k->nb[1]; + const uint32_t nbk2 = k->nb[2]; + const uint32_t nbk3 = k->nb[3]; + + const uint32_t nbv1 = v->nb[1]; + const uint32_t nbv2 = v->nb[2]; + const uint32_t nbv3 = v->nb[3]; + + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // total rows in q + const uint32_t nr = neq1*neq2*neq3; + + const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, nr); + + if (ir0 >= ir1) return; + + dma_queue * dma = octx->ctx->dma[ith]; + + const uint32_t DK = nek0; + const uint32_t DV = nev0; + + const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); + const size_t size_q_row_padded = htp_round_up(size_q_row, 128); + + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask + + const size_t size_k_row_padded = htp_round_up(size_k_row, 128); + const size_t size_v_row_padded = htp_round_up(size_v_row, 128); + + const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator + uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; + uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith; + uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith; + uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; + uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); + + const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + + const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + + // Fetch Q row + const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + // Clear accumulator + float * VKQ32 = (float *) spad_a; + memset(VKQ32, 0, DV * sizeof(float)); + + const __fp16 * mp_base = NULL; + if (mask) { + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); + } + + const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + // Prefetch first two blocks + for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); + uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + // Mask is 1D contiguous for this row + dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + } + } + + const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + + for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // Wait for DMA + uint8_t * k_base = dma_queue_pop(dma).dst; // K + uint8_t * v_base = dma_queue_pop(dma).dst; // V + __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + + // Inner loop processing the block from VTCM + uint32_t ic = 0; + + // Process in blocks of 32 (VLEN_FP32) + for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + // 1. Compute scores + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } + } + + HVX_Vector scores = *(HVX_Vector *) scores_arr; + + // 2. Softcap + if (logit_softcap != 0.0f) { + scores = hvx_vec_tanh_fp32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 3. Mask + if (mask) { + const __fp16 * mp = m_base + ic; + HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; + + HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); + + HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); + + HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); + scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 4. Online Softmax Update + HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); + float m_block = hvx_vec_get_fp32(v_max); + + float M_old = M; + float M_new = (m_block > M) ? m_block : M; + M = M_new; + + float ms = expf(M_old - M_new); + + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + S = S * ms; + + HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); + float p_sum = hvx_vec_get_fp32(p_sum_vec); + S += p_sum; + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } + } + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * size_k_row_padded; + + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } + + if (logit_softcap != 0.0f) { + s_val = logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + float ms = 1.0f; + float vs = 1.0f; + + if (s_val > M) { + M = s_val; + ms = expf(Mold - M); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s_val - M); + } + + const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + + S = S * ms + vs; + } + + // Issue DMA for next+1 block (if exists) + if (ib + 2 < n_blocks) { + const uint32_t next_ib = ib + 2; + const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); + dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + } + } + } + + // sinks + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv); + + // Store result + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // dst is permuted + uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + + if (dst->type == HTP_TYPE_F32) { + hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } else if (dst->type == HTP_TYPE_F16) { + hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } + } +} + +static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + flash_attn_ext_f16_thread(octx, i, n); +} + +int op_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; + struct htp_tensor * dst = &octx->dst; + + // Check support + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || + k->type != HTP_TYPE_F16 || + v->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + octx->src0_div1 = init_fastdiv_values(q->ne[1]); + + octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + + if (mask) { + octx->src3_div2 = init_fastdiv_values(mask->ne[2]); + octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = size_q_row_padded * 1; // single row for now + size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + + octx->src0_spad.size_per_thread = size_q_block * 1; + octx->src1_spad.size_per_thread = size_k_block * 2; + octx->src2_spad.size_per_thread = size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->dst_spad.size_per_thread = size_vkq_acc; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads; + octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size; + + if (octx->ctx->vtcm_size < total_spad) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c new file mode 100644 index 000000000..54321421e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define get_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t nr = ne10 * ne11 * ne12; + +static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + get_rows_preamble; + + // parallelize by src1 elements (which correspond to dst rows) + const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t rem = i - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + // invalid index, skip for now to avoid crash + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + + return HTP_STATUS_OK; +} + +static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 5c3d217f1..4bd0ea7a3 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -11,11 +11,6 @@ #define HTP_MAX_NTHREADS 10 -// FIXME: move these into matmul-ops -#define HTP_SPAD_SRC0_NROWS 16 -#define HTP_SPAD_SRC1_NROWS 16 -#define HTP_SPAD_DST_NROWS 2 - // Main context for htp DSP backend struct htp_context { dspqueue_t queue; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index a61652304..846d06178 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -36,6 +36,8 @@ enum htp_data_type { HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -57,6 +59,10 @@ enum htp_op { HTP_OP_SOFTMAX = 11, HTP_OP_ADD_ID = 12, HTP_OP_ROPE = 13, + HTP_OP_FLASH_ATTN_EXT = 14, + HTP_OP_SET_ROWS = 15, + HTP_OP_SCALE = 16, + HTP_OP_GET_ROWS = 17, INVALID }; @@ -137,6 +143,8 @@ struct htp_general_req { struct htp_tensor src0; // Input0 tensor struct htp_tensor src1; // Input1 tensor struct htp_tensor src2; // Input2 tensor + struct htp_tensor src3; // Input3 tensor + struct htp_tensor src4; // Input4 tensor struct htp_tensor dst; // Output tensor // should be multiple of 64 bytes (cacheline) @@ -152,6 +160,6 @@ struct htp_general_rsp { }; #define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 4 +#define HTP_MAX_PACKET_BUFFERS 8 #endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index e87657436..7c828ae63 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -13,6 +13,7 @@ struct htp_spad { uint8_t * data; + size_t stride; size_t size; size_t size_per_thread; }; @@ -26,11 +27,14 @@ struct htp_ops_context { struct htp_tensor src0; struct htp_tensor src1; struct htp_tensor src2; + struct htp_tensor src3; + struct htp_tensor src4; struct htp_tensor dst; struct htp_spad src0_spad; struct htp_spad src1_spad; struct htp_spad src2_spad; + struct htp_spad src3_spad; struct htp_spad dst_spad; worker_pool_context_t * wpool; // worker pool @@ -49,6 +53,27 @@ struct htp_ops_context { struct fastdiv_values src1_div3; // fastdiv values for ne3 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 + struct fastdiv_values src3_div1; // fastdiv values for ne1 + struct fastdiv_values src3_div2; // fastdiv values for ne2 + struct fastdiv_values src3_div3; // fastdiv values for ne3 + struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 + struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 + struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 + struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 + + struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 + struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 + + struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 + struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + uint32_t flags; }; @@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c index f9e02ab67..29d73b862 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); } -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in1 = (HVX_Vector *) src; - HVX_Vector * vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src, hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); } } + + diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index d2d5d2363..22876e6db 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) } #endif -static inline HVX_Vector hvx_vec_splat_fp32(float i) { +static inline HVX_Vector hvx_vec_splat_fp32(float v) { union { - float f; - int32_t i; - } fp32 = { .f = i }; + float f; + uint32_t i; + } fp32 = { .f = v }; return Q6_V_vsplat_R(fp32.i); } +static inline HVX_Vector hvx_vec_splat_fp16(float v) { + union { + __fp16 f; + uint16_t i; + } fp16 = { .f = v }; + + return Q6_Vh_vsplat_R(fp16.i); +} + static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) addr); @@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest } } +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { HVX_Vector * restrict vdst = (HVX_Vector *) dst; @@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3 return right_off <= chunk_size; } - - static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { HVX_VectorAlias u = { .v = v }; @@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { } static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HTP_ARCH__ > 75 +#if __HVX_ARCH__ > 75 return Q6_Vsf_vfneg_Vsf(v); #else // neg by setting the fp32 sign bit HVX_Vector mask = Q6_V_vsplat_R(0x80000000); return Q6_V_vxor_VV(v, mask); -#endif // __HTP_ARCH__ > 75 +#endif // __HVX_ARCH__ > 75 } // ==================================================== @@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } +static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_fp32(2.0f); + HVX_Vector one = hvx_vec_splat_fp32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; @@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr } } +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } +} + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } +} float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); void hvx_mul_f32(const uint8_t * restrict src0, @@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0, uint8_t * restrict dst, const int num_elems); void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale); void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fb5508a56..24b3e90e4 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_get_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_matmul_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - int write_idx = (n_bufs == 4) ? 3 : 2; + int write_idx = n_bufs - 1; // We had written to the output buffer, we'd also need to flush it rsp_bufs[0].fd = bufs[write_idx].fd; @@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_set_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_flash_attn_ext_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + // Setup Op context + struct htp_ops_context octx; + memset(&octx, 0, sizeof(octx)); + + octx.ctx = ctx; + octx.n_threads = ctx->n_threads; + + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.src3 = req->src3; + octx.src4 = req->src4; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + + int last_buf = 3; + + if (octx.src3.ne[0]) { + octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + } + + if (octx.src4.ne[0]) { + octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + } + + octx.dst.data = (uint32_t) bufs[last_buf].ptr; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_flash_attn_ext(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + + struct dspqueue_buffer rsp_buf = bufs[last_buf]; + rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); +} + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { break; case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_rope_req(ctx, &req, bufs, n_bufs); break; + case HTP_OP_FLASH_ATTN_EXT: + if (!(n_bufs >= 4 && n_bufs <= 6)) { + FARF(ERROR, "Bad flash-attn-ext-req buffer list"); + continue; + } + proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_SET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad set-rows-req buffer list"); + continue; + } + proc_set_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_GET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad get-rows-req buffer list"); + continue; + } + proc_get_rows_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index f14523d48..9bb39db9f 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -26,14 +26,14 @@ #include "hvx-utils.h" #include "ops-utils.h" +#define MM_SPAD_SRC0_NROWS 16 +#define MM_SPAD_SRC1_NROWS 16 +#define MM_SPAD_DST_NROWS 2 + struct htp_matmul_type { const char * type; void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy); + void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); }; typedef struct { @@ -907,145 +907,174 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); } -#if 1 -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - if (0) { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; +static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - for (uint32_t i = 0; i < n; i++) { - rsum += (float)vx[i] * vy[i]; - } - *s = rsum; - return; - } + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; - const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y; + HVX_Vector rsum = Q6_V_vsplat_R(0); - uint32_t nv0 = n / 64; // num full fp16 hvx vectors - uint32_t nv1 = n % 64; // leftover elements - - // for some reason we need volatile here so that the compiler doesn't try anything funky - volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - float r_sum_scalar = 0.0f; uint32_t i = 0; - for (i = 0; i < nv0; i++) { - HVX_VectorPair yp = vy[i]; - - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - //NOTE: need volatile here to prevent compiler optimization - // Seem compiler cannot guarantee read-after-write?? - volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - - HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - if (nv1) { - // HVX_VectorPair yp = vy[i]; + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - // HVX_Vector x = vx[i]; - // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - // if (nv1 >= 32) { - // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - // nv1 -= 32; - // } - - // rsum = hvx_vec_qf32_reduce_sum(rsum); - - // if (nv1) { - // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); - // } - - //process the remainder using scalar loop - rsum = hvx_vec_qf32_reduce_sum(rsum); - const __fp16 * restrict sx = (const __fp16 * restrict) x; - const float * restrict sy = (const float * restrict) y; - - for (uint32_t i = nv0 * 64; i < n; i++) { - r_sum_scalar += (float) sx[i] * sy[i]; - } - - // hvx_vec_dump_fp16("X", x); - // hvx_vec_dump_fp16("Y", y); - // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum)); - // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum)); - } else { - rsum = hvx_vec_qf32_reduce_sum(rsum); + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; - -# ifdef HTP_DEBUG - { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; - - for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * vy[i]; - } - - float diff = fabs(*s - rsum); - if (diff > 0.001) { - FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s); - // htp_dump_f16("x", vx, n); - // htp_dump_f32("y", vy, n); - } - } -# endif + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); } -#else -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const uint32_t fk = 64; - const uint32_t nb = n / fk; - assert(n % fk == 0); - assert(nb % 4 == 0); +static void vec_dot_f16_f16_aa_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; + const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - const uint32_t x_blk_size = 2 * fk; // fp16 - const uint32_t y_blk_size = 4 * fk; // fp32 + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; - // Row sum (qf32) HVX_Vector rsum0 = Q6_V_vsplat_R(0); HVX_Vector rsum1 = Q6_V_vsplat_R(0); - HVX_Vector rsum2 = Q6_V_vsplat_R(0); - HVX_Vector rsum3 = Q6_V_vsplat_R(0); - for (uint32_t i = 0; i < nb; i += 4) { - HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size)); - HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size)); + uint32_t i = 0; - HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]); - HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]); - HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]); - HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = y[i]; + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1))); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2))); - rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3))); + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - // Reduce and convert into fp32 - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3); - HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2)); - hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum)); -} -#endif + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); -#define htp_matmul_preamble \ + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_UVector * restrict x = (const HVX_UVector *) vx; + const HVX_UVector * restrict y = (const HVX_UVector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +#define htp_matmul_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict src2 = &octx->src2; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -1056,6 +1085,11 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ const uint32_t ne0 = dst->ne[0]; \ const uint32_t ne1 = dst->ne[1]; \ const uint32_t ne2 = dst->ne[2]; \ @@ -1076,18 +1110,94 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -// q8x4 src1 tensor is already in VTCM spad -static void matmul(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +#define htp_matmul_preamble \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + +// *** matmul with support for 4d tensors and full broadcasting + +static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const uint32_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const uint32_t nr1 = ne1 * ne2 * ne3; + + // distribute the thread work across the inner or outer loop based on which one is larger + uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + // The number of elements in each chunk + const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + uint32_t current_chunk = ith; + + const uint32_t ith0 = current_chunk % nchunk0; + const uint32_t ith1 = current_chunk / nchunk0; + + const uint32_t ir0_start = dr0 * ith0; + const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); + + const uint32_t ir1_start = dr1 * ith1; + const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); + + // no work for this thread + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + // block-tiling attempt + const uint32_t blck_0 = 64; + const uint32_t blck_1 = 64; + + for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { + const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + + const uint32_t i1 = i11; + const uint32_t i2 = i12; + const uint32_t i3 = i13; + + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// src1 tensor is already in VTCM spad +static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1104,9 +1214,10 @@ static void matmul(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1124,11 +1235,11 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows @@ -1137,17 +1248,17 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1155,13 +1266,13 @@ static void matmul(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const int is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); } @@ -1176,17 +1287,7 @@ static void matmul(struct htp_matmul_type * mt, } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1202,9 +1303,10 @@ static void matvec(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1226,24 +1328,24 @@ static void matvec(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1251,8 +1353,8 @@ static void matvec(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } @@ -1274,22 +1376,13 @@ struct mmid_row_mapping { uint32_t i2; }; -// q8x4 src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict ids, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1340,7 +1433,7 @@ static void matmul_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1365,8 +1458,8 @@ static void matmul_id(struct htp_matmul_type * mt, } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1404,22 +1497,13 @@ static void matmul_id(struct htp_matmul_type * mt, dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4 src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict src2, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1464,7 +1548,7 @@ static void matvec_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1477,8 +1561,8 @@ static void matvec_id(struct htp_matmul_type * mt, mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1504,106 +1588,6 @@ static void matvec_id(struct htp_matmul_type * mt, dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** matmul in fp16 - -static void matmul_f16_f32(struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_matmul_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const uint32_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const uint32_t nr1 = ne1 * ne2 * ne3; - - // distribute the thread work across the inner or outer loop based on which one is larger - uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - // The number of elements in each chunk - const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - uint32_t current_chunk = ith; - - const uint32_t ith0 = current_chunk % nchunk0; - const uint32_t ith1 = current_chunk / nchunk0; - - const uint32_t ir0_start = dr0 * ith0; - const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); - - const uint32_t ir1_start = dr1 * ith1; - const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); - - // broadcast factors - const uint32_t r2 = ne12 / ne02; - const uint32_t r3 = ne13 / ne03; - - // no work for this thread - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - // block-tiling attempt - const uint32_t blck_0 = 64; - const uint32_t blck_1 = 64; - - __attribute__((aligned(128))) float tmp[64]; - - for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = (ir1 / (ne12 * ne1)); - const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const uint32_t i03 = i13 / r3; - const uint32_t i02 = i12 / r2; - - const uint32_t i1 = i11; - const uint32_t i2 = i12; - const uint32_t i3 = i13; - - const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); - const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); - float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { - // Use nb01 stride for non-contiguous src0 support - const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); - } - - hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); - } - } - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - // *** dynamic quant static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { @@ -1780,20 +1764,14 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u for (uint32_t i = 0; i < nb; i++) { #if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #else #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" #endif @@ -1848,14 +1826,95 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// TODO just a plain copy that should be done via the DMA during the Op setup +static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); } -// ** matmul callbacks for worker_pool +static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} -static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} + +// ** matmul/matvec callbacks for worker_pool + +static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1863,11 +1922,10 @@ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1875,11 +1933,10 @@ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1887,11 +1944,10 @@ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1899,11 +1955,10 @@ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1911,11 +1966,10 @@ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1923,14 +1977,49 @@ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matvec_2d(&mt, octx, n, i); +} + +static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matmul_2d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f32"; + mt.vec_dot = vec_dot_f16_f32_uu; + + matmul_4d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_uu; + + matmul_4d(&mt, octx, n, i); } // ** matmul-id callbacks for worker_pool @@ -1943,8 +2032,7 @@ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1955,8 +2043,7 @@ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1967,8 +2054,7 @@ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1979,8 +2065,7 @@ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1991,8 +2076,7 @@ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -2003,18 +2087,17 @@ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } // ** main matmul entry point -int op_matmul(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; +} - htp_matmul_preamble; +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; const char * op_type; @@ -2038,9 +2121,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q4x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2048,8 +2131,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2067,9 +2150,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q8x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q8x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q8x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2077,8 +2160,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2096,9 +2179,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "mxfp4x4x2-f32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2106,8 +2189,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2122,20 +2205,69 @@ int op_matmul(struct htp_ops_context * octx) { break; case HTP_TYPE_F16: - op_type = "f16-f32"; - quant_job_func = NULL; // htp_quantize_f32_f16; - matmul_job_func = htp_matmul_f16_f32; + { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - // For all tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256); + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - need_quant = false; + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + op_type = "f16-f16"; + quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_2d_f16_f16; + } else { + matmul_job_func = htp_matvec_2d_f16_f16; + } + + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + op_type = "f16-f32"; + matmul_job_func = htp_matmul_4d_f16_f32; + } else { + op_type = "f16-f16"; + matmul_job_func = htp_matmul_4d_f16_f16; + } + + src1_row_size = nb11; // original row size in DDR + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } break; default: @@ -2166,6 +2298,9 @@ int op_matmul(struct htp_ops_context * octx) { octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + if (need_quant) { // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -2185,12 +2320,9 @@ int op_matmul(struct htp_ops_context * octx) { // ** main matmul-id entry point int op_matmul_id(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * ids = &octx->src2; - struct htp_tensor * dst = &octx->dst; + htp_matmul_tensors_preamble; - htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; const char * op_type; @@ -2228,8 +2360,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2257,8 +2389,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2286,8 +2418,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c new file mode 100644 index 000000000..bdd64fcc8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -0,0 +1,168 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t ne1 = octx->dst.ne[1]; \ + \ + const uint32_t nr = ne01; + +static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + // copy row + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); +} + +static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_set_rows(struct htp_ops_context * octx) { + set_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->set_rows_div_ne12 = init_fastdiv_values(ne12); + octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + switch(octx->dst.type) { + case HTP_TYPE_F32: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + break; + case HTP_TYPE_F16: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 5bf0cbf79..80d249a22 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, (const uint8_t *) mp_f32, slope); } else { - hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale); + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); if (mp_f32) { if (softmax_ctx->use_f16) { for (int i = 0; i < ne00; ++i) { @@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum); + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); } } } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index bb7557b02..8ed1e5b66 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void scale_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float scale = 0.f; + float bias = 0.f; + memcpy(&scale, &op_params[0], sizeof(float)); + memcpy(&bias, &op_params[1], sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + } + + hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + static void rms_norm_htp_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src, const float mean = sum / row_elems; const float scale = 1.0f / sqrtf(mean + epsilon); - hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale); + hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); } } } @@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_RMS_NORM: rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SCALE: + scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "rmsnorm-f32"; break; + case HTP_OP_SCALE: + unary_op_func = unary_job_dispatcher_f32; + op_type = "scale-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index b2e651e74..1a7d8c9fd 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf" device="HTP0" [ "$D" != "" ] && device="$D" -verbose="" -[ "$V" != "" ] && verbose="$V" +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" opmask= [ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" @@ -34,7 +40,7 @@ adb $adbserial shell " \ cd $basedir; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ + $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --batch-size 128 -ngl 99 $@ \ + --batch-size 128 -ngl 99 $cli_opts $@ \ " From 193ee38a1bc2aed807823112cf030edcae186892 Mon Sep 17 00:00:00 2001 From: Raul Torres <138264735+rauletorresc@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:01:25 +0000 Subject: [PATCH 10/19] CANN: Rename `get_env` to `get_env_as_lowercase` (#18624) --- ggml/src/ggml-cann/ggml-cann.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index acf88db9b..162d238ae 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() { } /** - * @brief Get the value of the specified environment variable (name). + * @brief Get the value of the specified environment variable (name) as lowercase. * if not empty, return a std::string object */ -std::optional get_env(const std::string & name) { +std::optional get_env_as_lowercase(const std::string & name) { const char * val = std::getenv(name.c_str()); if (!val) { return std::nullopt; @@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf_prio(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @return A unique pointer to the created CANN pool. */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(int device) { - std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); + std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or(""); if (mem_pool_type == "prio") { GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); @@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, // Why aclrtSynchronizeDevice? // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { @@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t int64_t ne0 = tensor->ne[0]; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); // last line must bigger than 32, because every single op deal at // least 32 bytes. @@ -2136,7 +2136,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. - static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or("")); + static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); if (!use_cann_graph || cann_graph_capture_required) { for (int i = 0; i < cgraph->n_nodes; i++) { @@ -2201,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { // Do not use acl_graph for prefill. for (int i = 0; i < cgraph->n_nodes; i++) { From 3333951d86c0a67bac318cc6b0a924b673f6e4e5 Mon Sep 17 00:00:00 2001 From: hipudding Date: Wed, 7 Jan 2026 16:11:31 +0800 Subject: [PATCH 11/19] CANN: Fix rename for get_env (#18652) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In #18624, get_env in ggml-cann was renamed to get_env_as_lowercase to accurately reflect the function’s behavior and reduce the chance of misuse. However, the update missed renaming call sites in other files. This commit fixes that oversight. --- ggml/src/ggml-cann/aclnn_ops.cpp | 2 +- ggml/src/ggml-cann/common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 50b6bd00e..6b718e01c 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1963,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * acl_tensor_ptr acl_weight_tensor; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index e9a21e1b0..6895349b2 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info(); void ggml_cann_set_device(int32_t device); int32_t ggml_cann_get_device(); -std::optional get_env(const std::string & name); +std::optional get_env_as_lowercase(const std::string & name); bool parse_bool(const std::string & value); int parse_integer(const std::string & value); From ffba4f29e6a9ed7165ea6b94150856c5b49925cb Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 7 Jan 2026 10:42:19 +0100 Subject: [PATCH 12/19] examples : add debug utility/example (#18464) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * examples : add debug utility/example This commit introduces a new example named llama-debug which is a utility that is intended to be used to assist with developing/debugging a converted model. The motivation for this utilitiy is to assist in model conversion work to verify that the model produces the expected outputs. It is intended to replace logits.cpp in examples/model-conversion. Example usage: ```console ./build/bin/llama-debug \ -m models/Qwen2.5-0.5B-Instruct.gguf \ --prompt "Hello, my name is" \ --save-logits ... Model add_bos: false Input prompt: "Hello, my name is" Token ids (5): Hello(9707) ,(11) my(847) name(829) is(374) Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.bin Data saved to data/llamacpp-Qwen2.5-0.5B-Instruct.txt Prompt saved to data/llamacpp-Qwen2.5-0.5B-Instruct-prompt.txt Tokens saved to data/llamacpp-Qwen2.5-0.5B-Instruct-tokens.bin ``` For more details about the options available for this example, please refer to examples/debug/README.md. * throw runtime error instead of logging error * remove params.warmup and enable the warmup/nowarmup option * model-conversion : remove logits.cpp This commit removes logits.cpp in favor of using llama-debug for generating logits and embeddings. * examples : remove model-conversion directory This was missed in the previous commit. * model-conversion : add support for saving prompt and token ids This commit add support for storing the prompt and the token ids for the prompt when running the original models. The motivation for this is that this will allow us to compare the prompt and the tokens generated for the prompt when verifing the converted model. Currently it is possible that even if the same prompt is used that the tokens generated are different if there is a difference in the tokenization between the original and converted model which would currently go unnoticed (the verification will most likely fail but it might not be obvious why). * squash! model-conversion : add support for saving prompt and token ids fix pyright errors. * model-conversion : add compare_tokens utility This commit adds a script to compare token outputs between original and converted models. Example usage: ```console (venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 Comparing tokens between: Original : pytorch-gemma-3-270m-it (6 tokens) Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens) ✅ All 6 tokens match! ``` And there is a verbose flag that will also print out the prompts: ```console (venv) $ ./scripts/utils/compare_tokens.py pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 -v Original model prompt (pytorch-gemma-3-270m-it): prompt: Hello, my name is n_tokens: 6 token ids: 2, 9259, 236764, 1041, 1463, 563 Converted model prompt (llamacpp-gemma-3-270m-it-bf16): prompt: Hello, my name is n_tokens: 6 token ids: 2, 9259, 236764, 1041, 1463, 563 Comparing tokens between: Original : pytorch-gemma-3-270m-it (6 tokens) Converted: llamacpp-gemma-3-270m-it-bf16 (6 tokens) ✅ All 6 tokens match! ``` * model-conversion : add token comparison to verifiction scripts This commit add the calling of the compare_tokens function in compare-logits.py and semantic_check.py to ensure that the token ids that the tokenizers procoduce are the same before proceeding with verifying the logits/embeddings. Placing them in the existing scripts instead calling them separately ensures that the token comparison is always done prior to the logit/embedding verifications. Follow up commit/pr could refactor the causal logits verification into a single script instead of the two that exist now. This would reduce the code and make it consistent with the embeddings verficiation which only has a single script. * debug : use llama_model_n_embd_out This commit updates the debug example to use the new function llama_model_n_embd_out instead of llama_model_n_embd. The motivation for this change is to support late interation retriever models, like LFM2-ColBert-350M, where the output embeddings are down projected to a lower dimension. * debug : add print_usage function This commit adds a print_usage function that is passed to the common_params_parse. The motivation for this is that this enables a specific usage message which will be printed after all the options, for example: ```console example usage: Print tensors: ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --verbose The tensors to be printed can be filtered with --tensor-filter option. Save logits/embeddings: ./build/bin/llama-debug -m model.gguf -p "Hello my name is" --save-logits Add --embedding to save embeddings ``` --- common/arg.cpp | 29 +- common/common.h | 6 + examples/CMakeLists.txt | 2 +- .../CMakeLists.txt | 4 +- examples/debug/README.md | 54 +++ examples/debug/debug.cpp | 421 ++++++++++++++++++ examples/model-conversion/logits.cpp | 268 ----------- .../scripts/causal/compare-logits.py | 9 +- .../causal/run-casual-gen-embeddings-org.py | 2 +- .../run-converted-model-embeddings-logits.sh | 4 +- .../scripts/causal/run-converted-model.sh | 4 +- .../scripts/causal/run-org-model.py | 20 +- .../scripts/embedding/run-converted-model.sh | 7 +- .../scripts/embedding/run-original-model.py | 25 +- .../model-conversion/scripts/utils/common.py | 95 ++++ .../scripts/utils/compare_tokens.py | 76 ++++ .../scripts/utils/semantic_check.py | 18 + 17 files changed, 725 insertions(+), 319 deletions(-) rename examples/{model-conversion => debug}/CMakeLists.txt (73%) create mode 100644 examples/debug/README.md create mode 100644 examples/debug/debug.cpp delete mode 100644 examples/model-conversion/logits.cpp create mode 100755 examples/model-conversion/scripts/utils/compare_tokens.py diff --git a/common/arg.cpp b/common/arg.cpp index c3610d262..a67a26e2d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1445,7 +1445,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.warmup = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -1761,7 +1761,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING")); add_opt(common_arg( {"--attention"}, "{causal,non-causal}", "attention type for embeddings, use model default if unspecified", @@ -2609,7 +2609,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", @@ -2687,7 +2687,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.embedding = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(common_arg( {"--rerank", "--reranking"}, string_format("enable reranking endpoint on server (default: %s)", "disabled"), @@ -3378,6 +3378,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({ LLAMA_EXAMPLE_FINETUNE })); + add_opt(common_arg( + {"--save-logits"}, + string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"), + [](common_params & params) { + params.save_logits = true; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--logits-output-dir"}, "PATH", + string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()), + [](common_params & params, const std::string & value) { + params.logits_output_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--tensor-filter"}, "REGEX", + "filter tensor names for debug output (regex pattern, can be specified multiple times)", + [](common_params & params, const std::string & value) { + params.tensor_filter.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); // presets add_opt(common_arg( diff --git a/common/common.h b/common/common.h index daea6ded5..d6fd0d37a 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -372,6 +373,11 @@ struct common_params { std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 91797cf78..a29dc707c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,6 +15,7 @@ llama_add_compile_flags() if (EMSCRIPTEN) else() add_subdirectory(batched) + add_subdirectory(debug) add_subdirectory(embedding) add_subdirectory(eval-callback) @@ -34,7 +35,6 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) - add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/model-conversion/CMakeLists.txt b/examples/debug/CMakeLists.txt similarity index 73% rename from examples/model-conversion/CMakeLists.txt rename to examples/debug/CMakeLists.txt index fc1746ce4..34593072b 100644 --- a/examples/model-conversion/CMakeLists.txt +++ b/examples/debug/CMakeLists.txt @@ -1,5 +1,5 @@ -set(TARGET llama-logits) -add_executable(${TARGET} logits.cpp) +set(TARGET llama-debug) +add_executable(${TARGET} debug.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/debug/README.md b/examples/debug/README.md new file mode 100644 index 000000000..28e00c934 --- /dev/null +++ b/examples/debug/README.md @@ -0,0 +1,54 @@ +# llama.cpp/examples/debug + +This is a utility intended to help debug a model by registering a callback that +logs GGML operations and tensor data. It can also store the generated logits or +embeddings as well as the prompt and token ids for comparision with the original +model. + +### Usage + +```shell +llama-debug \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --save-logits \ + --verbose +``` +The tensor data is logged as debug and required the --verbose flag. The reason +for this is that while useful for a model with many layers there can be a lot of +output. You can filter the tensor names using the `--tensor-filter` option. + +A recommended approach is to first run without `--verbose` and see if the +generated logits/embeddings are close to the original model. If they are not, +then it might be required to inspect tensor by tensor and in that case it is +useful to enable the `--verbose` flag along with `--tensor-filter` to focus on +specific tensors. + +### Options +This example supports all standard `llama.cpp` options and also accepts the +following options: +```console +$ llama-debug --help +... + +----- example-specific params ----- + +--save-logits save final logits to files for verification (default: false) +--logits-output-dir PATH directory for saving logits output files (default: data) +--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times) +``` + +### Output Files + +When `--save-logits` is enabled, the following files are created in the output +directory: + +* `llamacpp-[-embeddings].bin` - Binary output (logits or embeddings) +* `llamacpp-[-embeddings].txt` - Text output (logits or embeddings, one per line) +* `llamacpp-[-embeddings]-prompt.txt` - Prompt text and token IDs +* `llamacpp-[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison + +These files can be compared against the original model's output to verify the +converted model. diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp new file mode 100644 index 000000000..9bc5d0abf --- /dev/null +++ b/examples/debug/debug.cpp @@ -0,0 +1,421 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + const std::string usage_template = R"( + example usage: + + Print tensors: + + {prog} -m model.gguf -p "Hello my name is" --verbose + + The tensors to be printed can be filtered with --tensor-filter option. + + Save logits/embeddings: + + {prog} -m model.gguf -p "Hello my name is" --save-logits + + Add --embedding to save embeddings)" "\n"; + + // Fix the source code indentation above that is introduced by the raw string literal. + std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n"); + usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]); + LOG("%s\n", usage.c_str()); +} + +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data); + +struct callback_data { + std::vector data; + std::vector tensor_filters; + + callback_data() = default; + + callback_data(common_params & params, const std::vector & filter_patterns) { + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + params.cb_eval = ggml_debug; + params.cb_eval_user_data = this; + } +}; + +struct output_data { + float * data_ptr = nullptr; + int data_size = 0; + std::string type_suffix; + std::vector storage; + std::string prompt; + std::vector tokens; + + output_data(llama_context * ctx, const llama_model * model, const common_params & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + tokens = common_tokenize(ctx, params.prompt, add_bos); + prompt = params.prompt; + + if (params.embedding) { + const int n_embd = llama_model_n_embd_out(model); + const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE; + const int n_embd_count = pooling_enabled ? 1 : tokens.size(); + const int n_embeddings = n_embd * n_embd_count; + + float * embeddings; + if (pooling_enabled) { + embeddings = llama_get_embeddings_seq(ctx, 0); + storage.resize(n_embeddings); + common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize); + embeddings = storage.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } + + data_ptr = embeddings; + data_size = n_embeddings; + type_suffix = "-embeddings"; + } else { + const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1); + const int n_logits = llama_vocab_n_tokens(vocab); + + data_ptr = const_cast(logits); + data_size = n_logits; + type_suffix = ""; + } + } +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +static float ggml_get_float_value(const uint8_t * data, ggml_type type, + const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + switch (type) { + case GGML_TYPE_F16: + return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); + case GGML_TYPE_F32: + return *(const float *) &data[i]; + case GGML_TYPE_I64: + return (float) *(const int64_t *) &data[i]; + case GGML_TYPE_I32: + return (float) *(const int32_t *) &data[i]; + case GGML_TYPE_I16: + return (float) *(const int16_t *) &data[i]; + case GGML_TYPE_I8: + return (float) *(const int8_t *) &data[i]; + case GGML_TYPE_BF16: + return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); + default: + GGML_ABORT("fatal error"); + } +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + float sum_sq = 0.0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + sum_sq += v * v; + } + } + } + } + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LOG_DBG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LOG_DBG(" ..., \n"); + i2 = ne[2] - n; + } + LOG_DBG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LOG_DBG(" ..., \n"); + i1 = ne[1] - n; + } + LOG_DBG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LOG_DBG("..., "); + i0 = ne[0] - n; + } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + LOG_DBG("%12.4f", v); + if (i0 < ne[0] - 1) { + LOG_DBG(", "); + } + } + LOG_DBG("],\n"); + } + LOG_DBG(" ],\n"); + } + LOG_DBG(" ]\n"); + LOG_DBG(" sum = %f\n", sum); + LOG_DBG(" sum_sq = %f\n", sum_sq); + } + + if (std::isnan(sum)) { + LOG_ERR("encountered NaN - aborting\n"); + exit(0); + } +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * see ggml_backend_sched_eval_callback + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + bool matches_filter = cb_data->tensor_filters.empty(); + + if (!matches_filter) { + for (const auto & filter : cb_data->tensor_filters) { + if (std::regex_search(t->name, filter)) { + matches_filter = true; + break; + } + } + } + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + if (matches_filter) { + LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, + ggml_type_name(t->type), + ggml_op_desc(t), + src0->name, + ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + } + + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + if (!ggml_is_quantized(t->type) && matches_filter) { + uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + } + + return true; +} + + +static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) { + std::filesystem::create_directory(output_dir); + auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix); + + // Save logits/embeddings to binary file. + { + std::filesystem::path filepath{base_path.string() + ".bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open binary output file: " + filepath.string()); + } + file.write(reinterpret_cast(output.data_ptr), output.data_size * sizeof(float)); + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save logits/embeddings to text file. + { + std::filesystem::path filepath{base_path.string() + ".txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open text output file: " + filepath.string()); + } + for (int i = 0; i < output.data_size; i++) { + file << i << ": " << output.data_ptr[i] << '\n'; + } + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save prompt and tokens to text file. + { + std::filesystem::path filepath{base_path.string() + "-prompt.txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open prompt output file: " + filepath.string()); + } + + file << "prompt: " << output.prompt << '\n'; + file << "n_tokens: " << output.tokens.size() << '\n'; + + file << "token ids: "; + for (size_t i = 0; i < output.tokens.size(); i++) { + file << output.tokens[i]; + if (i + 1 < output.tokens.size()) { + file << ", "; + } + } + file << '\n'; + LOG("Prompt saved to %s\n", filepath.c_str()); + } + + // Save token ids to binary file. + { + std::filesystem::path filepath{base_path.string() + "-tokens.bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open tokens binary file: " + filepath.string()); + } + file.write(reinterpret_cast(output.tokens.data()), output.tokens.size() * sizeof(llama_token)); + LOG("Tokens saved to %s\n", filepath.c_str()); + } + +} + +static void print_tokenized_prompt(llama_context * ctx, const std::vector & tokens, const std::string & prompt) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false"); + LOG("Input prompt: \"%s\"\n", prompt.c_str()); + LOG("Token ids (%zu):\n", tokens.size()); + + for (auto id : tokens) { + std::string piece(128, '\0'); + int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true); + if (n < 0) { + LOG_ERR("failed to convert token %d to piece\n", id); + continue; + } + piece.resize(n); + LOG("%s(%d) ", piece.c_str(), id); + } + LOG("\n"); +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + print_tokenized_prompt(ctx, tokens, params.prompt); + + if (params.save_logits) { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } + + return true; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + callback_data cb_data(params, params.tensor_filter); + + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + if (!run(ctx, params)) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp deleted file mode 100644 index f71f772ab..000000000 --- a/examples/model-conversion/logits.cpp +++ /dev/null @@ -1,268 +0,0 @@ -#include "llama.h" -#include "common.h" - - -#include -#include -#include -#include -#include -#include - -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); - printf("\n"); - printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); - printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); - printf("\n"); -} - -int main(int argc, char ** argv) { - std::string model_path; - std::string prompt = "Hello, my name is"; - int ngl = 0; - bool embedding_mode = false; - bool pooling_enabled = false; - int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - { - int i = 1; - for (; i < argc; i++) { - if (strcmp(argv[i], "-m") == 0) { - if (i + 1 < argc) { - model_path = argv[++i]; - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-ngl") == 0) { - if (i + 1 < argc) { - try { - ngl = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-embd-mode") == 0) { - embedding_mode = true; - } else if (strcmp(argv[i], "-pooling") == 0) { - pooling_enabled = true; - } else if (strcmp(argv[i], "-embd-norm") == 0) { - if (i + 1 < argc) { - try { - embd_norm = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else { - // prompt starts here - break; - } - } - - if (model_path.empty()) { - print_usage(argc, argv); - return 1; - } - - if (i < argc) { - prompt = argv[i++]; - for (; i < argc; i++) { - prompt += " "; - prompt += argv[i]; - } - } - } - - ggml_backend_load_all(); - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - - if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return 1; - } - - // Extract basename from model_path - const char * basename = strrchr(model_path.c_str(), '/'); - basename = (basename == NULL) ? model_path.c_str() : basename + 1; - - char model_name[256]; - strncpy(model_name, basename, 255); - model_name[255] = '\0'; - - char * dot = strrchr(model_name, '.'); - if (dot != NULL && strcmp(dot, ".gguf") == 0) { - *dot = '\0'; - } - printf("Model name: %s\n", model_name); - - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); - - std::vector prompt_tokens(n_prompt); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); - return 1; - } - - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_prompt; - ctx_params.n_batch = n_prompt; - ctx_params.no_perf = false; - if (embedding_mode) { - ctx_params.embeddings = true; - ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; - ctx_params.n_ubatch = ctx_params.n_batch; - } - - llama_context * ctx = llama_init_from_model(model, ctx_params); - if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); - return 1; - } - - printf("Input prompt: \"%s\"\n", prompt.c_str()); - printf("Tokenized prompt (%d tokens): ", n_prompt); - for (auto id : prompt_tokens) { - char buf[128]; - int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); - if (n < 0) { - fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); - return 1; - } - std::string s(buf, n); - printf("%s (%d)", s.c_str(), id); - } - printf("\n"); - - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); - - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - - float * data_ptr; - int data_size; - const char * type; - std::vector embd_out; - - if (embedding_mode) { - const int n_embd_out = llama_model_n_embd_out(model); - const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; - const int n_embeddings = n_embd_out * n_embd_count; - float * embeddings; - type = "-embeddings"; - - if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { - embeddings = llama_get_embeddings_seq(ctx, 0); - embd_out.resize(n_embeddings); - printf("Normalizing embeddings using norm: %d\n", embd_norm); - common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); - embeddings = embd_out.data(); - } else { - embeddings = llama_get_embeddings(ctx); - } - - printf("Embedding dimension: %d\n", n_embd_out); - printf("\n"); - - // Print embeddings in the specified format - for (int j = 0; j < n_embd_count; j++) { - printf("embedding %d: ", j); - - // Print first 3 values - for (int i = 0; i < 3 && i < n_embd_out; i++) { - printf("%9.6f ", embeddings[j * n_embd_out + i]); - } - - printf(" ... "); - - // Print last 3 values - for (int i = n_embd_out - 3; i < n_embd_out; i++) { - if (i >= 0) { - printf("%9.6f ", embeddings[j * n_embd_out + i]); - } - } - - printf("\n"); - } - printf("\n"); - - printf("Embeddings size: %d\n", n_embeddings); - - data_ptr = embeddings; - data_size = n_embeddings; - } else { - float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - const int n_logits = llama_vocab_n_tokens(vocab); - type = ""; - printf("Vocab size: %d\n", n_logits); - - data_ptr = logits; - data_size = n_logits; - } - - std::filesystem::create_directory("data"); - - // Save data to binary file - char bin_filename[512]; - snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving data to %s\n", bin_filename); - - FILE * f = fopen(bin_filename, "wb"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); - return 1; - } - fwrite(data_ptr, sizeof(float), data_size, f); - fclose(f); - - // Also save as text for debugging - char txt_filename[512]; - snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type); - f = fopen(txt_filename, "w"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open text output file\n", __func__); - return 1; - } - for (int i = 0; i < data_size; i++) { - fprintf(f, "%d: %.6f\n", i, data_ptr[i]); - } - fclose(f); - - if (!embedding_mode) { - printf("First 10 logits: "); - for (int i = 0; i < 10 && i < data_size; i++) { - printf("%.6f ", data_ptr[i]); - } - printf("\n"); - - printf("Last 10 logits: "); - for (int i = data_size - 10; i < data_size; i++) { - if (i >= 0) printf("%.6f ", data_ptr[i]); - } - printf("\n\n"); - } - - printf("Data saved to %s\n", bin_filename); - printf("Data saved to %s\n", txt_filename); - - llama_free(ctx); - llama_model_free(model); - - return 0; -} diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 894302c69..1a933207d 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -6,7 +6,7 @@ from pathlib import Path # Add utils directory to path for direct script execution sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) -from common import get_model_name_from_env_path # type: ignore[import-not-found] +from common import get_model_name_from_env_path, compare_tokens # type: ignore[import-not-found] def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -58,6 +58,13 @@ def main(): print("Checked all required files were found. Proceeding...\n") + # Verify tokens as they are a prerequisite for logits comparison. + print("🔍 Token Comparison Check") + print("=" * 40) + if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"): + print("\n❌ Token mismatch detected") + sys.exit(1) + print() print("🔍 GGML Model Validation for model ", model_name) print("=" * 40) diff --git a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py index 55ad82138..4ab778fbc 100755 --- a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py +++ b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py @@ -67,7 +67,7 @@ with torch.no_grad(): last_hidden_states = outputs.hidden_states[-1] # Get embeddings for all tokens - token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension + token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension print(f"Hidden states shape: {last_hidden_states.shape}") print(f"Token embeddings shape: {token_embeddings.shape}") diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh index fa16a02c6..3cce3fc94 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh @@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today" +../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index 529e9987b..b6c3d3866 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -21,6 +21,6 @@ fi echo $CONVERTED_MODEL echo $MODEL_TESTING_PROMPT -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" +../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index b12173a1f..215f1a9ee 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -7,12 +7,11 @@ import importlib import torch import numpy as np -from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig # Add parent directory to path for imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) -from utils.common import debug_hook +from utils.common import debug_hook, save_output_data def parse_arguments(): parser = argparse.ArgumentParser(description="Process model with specified path") @@ -126,6 +125,7 @@ def main(): device = next(model.parameters()).device prompt = get_prompt(args) input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + token_ids = input_ids[0].cpu().tolist() print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") @@ -151,19 +151,6 @@ def main(): print(f"Last token logits shape: {last_logits.shape}") print(f"Vocab size: {len(last_logits)}") - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}.bin" - txt_filename = data_dir / f"pytorch-{model_name}.txt" - - # Save to file for comparison - last_logits.astype(np.float32).tofile(bin_filename) - - # Also save as text file for easy inspection - with open(txt_filename, "w") as f: - for i, logit in enumerate(last_logits): - f.write(f"{i}: {logit:.6f}\n") - # Print some sample logits for quick verification print(f"First 10 logits: {last_logits[:10]}") print(f"Last 10 logits: {last_logits[-10:]}") @@ -175,8 +162,7 @@ def main(): token = tokenizer.decode([idx]) print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") - print(f"Saved bin logits to: {bin_filename}") - print(f"Saved txt logist to: {txt_filename}") + save_output_data(last_logits, token_ids, prompt, model_name) if __name__ == "__main__": main() diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 0f490e6c3..5d264b066 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -50,10 +50,9 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target llama-logits -j8 -# TODO: update logits.cpp to accept a --file/-f option for the prompt +cmake --build ../../build --target llama-debug -j8 if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits else - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits fi diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 774e5638f..0802cbcf4 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -3,13 +3,15 @@ import argparse import os import sys -import numpy as np import importlib -from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModel import torch +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from utils.common import save_output_data + def parse_arguments(): parser = argparse.ArgumentParser(description='Run original embedding model') @@ -169,6 +171,7 @@ def main(): return_tensors="pt" ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -185,6 +188,7 @@ def main(): ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -228,24 +232,11 @@ def main(): print() - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" - txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" - flattened_embeddings = all_embeddings.flatten() - flattened_embeddings.astype(np.float32).tofile(bin_filename) - - with open(txt_filename, "w") as f: - idx = 0 - for j in range(n_embd_count): - for value in all_embeddings[j]: - f.write(f"{idx}: {value:.6f}\n") - idx += 1 print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)") print("") - print(f"Saved bin embeddings to: {bin_filename}") - print(f"Saved txt embeddings to: {txt_filename}") + + save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings") if __name__ == "__main__": diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py index 7595d0410..71761127b 100644 --- a/examples/model-conversion/scripts/utils/common.py +++ b/examples/model-conversion/scripts/utils/common.py @@ -3,6 +3,8 @@ import os import sys import torch +import numpy as np +from pathlib import Path def get_model_name_from_env_path(env_path_name): @@ -148,3 +150,96 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_ # Patch it setattr(module, function_name, debug_rope) print(f"RoPE debug patching applied to {model_module_path}.{function_name}") + + +def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"): + """ + Save output data (logits/embeddings), tokens, and prompt to files. + + Args: + data: numpy array of floats (logits or embeddings) + tokens: list or array of token IDs + prompt: string containing the input prompt + model_name: name of the model + type_suffix: optional suffix like "-embeddings" (default: "") + output_dir: directory to save files (default: "data") + + Creates the following files in output_dir: + - pytorch-{model_name}{type_suffix}.bin + - pytorch-{model_name}{type_suffix}.txt + - pytorch-{model_name}{type_suffix}-prompt.txt + - pytorch-{model_name}{type_suffix}-tokens.bin + """ + data_dir = Path(output_dir) + data_dir.mkdir(exist_ok=True) + base_path = data_dir / f"pytorch-{model_name}{type_suffix}" + + # Convert and flatten logits/embeddings + data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data) + data = data.flatten() if data.ndim > 1 else data + + # Save logits/embedding files + data.astype(np.float32).tofile(f"{base_path}.bin") + print(f"Data saved to {base_path}.bin") + + with open(f"{base_path}.txt", "w") as f: + f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data)) + print(f"Data saved to {base_path}.txt") + + # Convert and flatten tokens + tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens) + tokens = tokens.flatten() if tokens.ndim > 1 else tokens + + # Save token binary file + tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin") + print(f"Tokens saved to {base_path}-tokens.bin") + + # Save prompt file + with open(f"{base_path}-prompt.txt", "w") as f: + f.write(f"prompt: {prompt}\n") + f.write(f"n_tokens: {len(tokens)}\n") + f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n") + print(f"Prompt saved to {base_path}-prompt.txt") + + +def compare_tokens(original, converted, type_suffix="", output_dir="data"): + data_dir = Path(output_dir) + + # Read tokens from both models + tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin" + tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin" + + if not tokens1_file.exists(): + print(f"Error: Token file not found: {tokens1_file}") + return False + + if not tokens2_file.exists(): + print(f"Error: Token file not found: {tokens2_file}") + return False + + tokens1 = np.fromfile(tokens1_file, dtype=np.int32) + tokens2 = np.fromfile(tokens2_file, dtype=np.int32) + + print(f"\nComparing tokens between:") + print(f" Original : {original} ({len(tokens1)} tokens)") + print(f" Converted: {converted} ({len(tokens2)} tokens)") + + if len(tokens1) != len(tokens2): + print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}") + return False + + if np.array_equal(tokens1, tokens2): + print(f"\n✅ All {len(tokens1)} tokens match!") + return True + + mismatches = np.where(tokens1 != tokens2)[0] + print(f"\n❌ Found {len(mismatches)} mismatched tokens:") + + num_to_show = min(len(mismatches), 10) + for idx in mismatches[:num_to_show]: + print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}") + + if len(mismatches) > num_to_show: + print(f" ... and {len(mismatches) - num_to_show} more mismatches") + + return False diff --git a/examples/model-conversion/scripts/utils/compare_tokens.py b/examples/model-conversion/scripts/utils/compare_tokens.py new file mode 100755 index 000000000..a286cb568 --- /dev/null +++ b/examples/model-conversion/scripts/utils/compare_tokens.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import argparse +import sys +from common import compare_tokens # type: ignore + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description='Compare tokens between two models', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 + """ + ) + parser.add_argument( + 'original', + help='Original model name' + ) + parser.add_argument( + 'converted', + help='Converted model name' + ) + parser.add_argument( + '-s', '--suffix', + default='', + help='Type suffix (e.g., "-embeddings")' + ) + parser.add_argument( + '-d', '--data-dir', + default='data', + help='Directory containing token files (default: data)' + ) + parser.add_argument( + '-v', '--verbose', + action='store_true', + help='Print prompts from both models' + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + if args.verbose: + from pathlib import Path + data_dir = Path(args.data_dir) + + prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt" + prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt" + + if prompt1_file.exists(): + print(f"\nOriginal model prompt ({args.original}):") + print(f" {prompt1_file.read_text().strip()}") + + if prompt2_file.exists(): + print(f"\nConverted model prompt ({args.converted}):") + print(f" {prompt2_file.read_text().strip()}") + + print() + + result = compare_tokens( + args.original, + args.converted, + type_suffix=args.suffix, + output_dir=args.data_dir + ) + + # Enable the script to be used in shell scripts so that they can check + # the exit code for success/failure. + sys.exit(0 if result else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index e64c00049..38b03ce4d 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -4,8 +4,10 @@ import numpy as np import argparse import os import importlib +from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel +from common import compare_tokens # type: ignore[import-not-found] unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') @@ -157,9 +159,25 @@ def main(): else: prompt = args.prompt + python_emb_path = Path(args.python_embeddings) + cpp_emb_path = Path(args.cpp_embeddings) + + # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name") + python_model_name = python_emb_path.stem.replace("-embeddings", "") + cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "") + print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") print("=" * 70) + # First verify tokens match before comparing embeddings + print("\n🔍 Token Comparison Check") + print("=" * 70) + data_dir = python_emb_path.parent + if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)): + print("\n❌ Token mismatch detected") + exit(1) + print() + # Single prompt detailed comparison print(f"\nTesting with prompt: '{prompt}'") From 8c77a04cc723909eab5d3bc3ae14c82f4db1afc7 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 7 Jan 2026 10:13:17 +0000 Subject: [PATCH 13/19] vulkan: more mul mat optimizations (#18533) * q4_k * q5_k * q2_k * q4_1 * q5_1 * better buf index --- .../vulkan-shaders/dequant_funcs.glsl | 3 +- .../vulkan-shaders/mul_mm_funcs.glsl | 82 ++++++++++--------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 47 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 376944f1e..7865a6bda 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { - return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); + const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); + return dm; } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 1a3531761..ce7f2d699 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); @@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; - const uint ib = idx / 8; - const uint iqs = idx & 0x07; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint uint_qh = data_a_packed32[ib].qh; + const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10); + const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10); + const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10); + const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; + const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); + const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303)); const uint scales = data_a[ib].scales[scalesi]; const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy); + const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F; - const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4; - const vec2 q = vec2(unpack8(qs | qh).xy); + const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F; + const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; + const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 5b61ff9ca..bbdbf9dca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { From 03023296cf63f4177f51db9126b16b06f0e0af98 Mon Sep 17 00:00:00 2001 From: virajwad <84867530+virajwad@users.noreply.github.com> Date: Wed, 7 Jan 2026 02:59:47 -0800 Subject: [PATCH 14/19] vulkan: Warptile tuning for Intel Xe2/Xe3 (#18178) * modify warptile tuning for xe3 * intel vendor check w/ coopmat support * fix back formatting * fix formatting change 2 * move intel check to chip specific tuning part * Change to support both windows and linux * modify m_warptile to l_warptile for intel * modify warptile tuning for bf16 matmuls to fix regression (m_warptile to l_warptile) * Code style changes * Code style changes (2) * Code style changes (3) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3c13777b8..1f255b705 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2996,6 +2996,10 @@ static void ggml_vk_load_shaders(vk_device& device) { if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) { + // Xe2/Xe3 with coopmat enabled - warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -3678,6 +3682,11 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; + if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { + // Xe2/Xe3 - bf16 warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; + } + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5061,11 +5070,23 @@ static vk_device ggml_vk_get_device(size_t idx) { switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS case VK_VENDOR_ID_AMD: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; case VK_VENDOR_ID_INTEL: - device->mul_mat_l[i] = false; + if (!device->coopmat_support || device->architecture != INTEL_XE2) { + device->mul_mat_l[i] = false; + device->mul_mat_id_l[i] = false; + } else { + device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel + device->mul_mat_id_l[i] = true; + } device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; - device->mul_mat_id_l[i] = false; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; From ca4a8370bc1ebf267073cfa29067ebeff7ab8015 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 7 Jan 2026 05:03:32 -0600 Subject: [PATCH 15/19] vulkan: reject ops when a tensor is too large to allocate (#18646) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 29 +++++++++++++--------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1f255b705..d68735a04 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14305,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const } static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + // reject any tensors larger than the max buffer size + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) { + return false; + } + } + if (ggml_nbytes(op) > device->max_buffer_size) { + return false; + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -14353,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); if (op->op == GGML_OP_MUL_MAT_ID) { if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU @@ -14415,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; uint32_t HSK = op->src[1]->ne[0]; uint32_t HSV = op->src[2]->ne[0]; @@ -14638,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // pipeline_argsort_large_f32 requires vulkan memory model. if (device->vulkan_memory_model) { return true; @@ -14652,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // We could potentially support larger, using argsort to sort the // whole thing. Not clear if this is needed. uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; @@ -14700,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_CUMSUM: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); } @@ -14709,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_SOLVE_TRI: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -14776,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - const uint32_t SPLIT_H = 16; size_t stateC_size = SPLIT_H * d_state * sizeof(float); From 9dfa8ee950b077b2d8a49caaa144dcc6bbc55305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 7 Jan 2026 13:07:08 +0100 Subject: [PATCH 16/19] ci : run cann build unconditionally [no ci] (#18659) --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1193779d0..85601b371 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1418,7 +1418,6 @@ jobs: echo "FIXME: test on devices" openEuler-latest-cmake-cann: - if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: run: shell: bash -el {0} From bb77764c2d024a6fecc5bdeb3618cb580ee15041 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 7 Jan 2026 13:18:53 +0100 Subject: [PATCH 17/19] convert : clarify sentence-transformers-dense-modules help [no ci] (#18662) * convert : clarify sentence-transformers-dense-modules help [no ci] This commit updates this options help message which currently looks like this: ```console --sentence-transformers-dense-modules Whether to include sentence-transformers dense modules.It can be used for sentence-transformers models, like google/embeddinggemma-300mDefault these modules are not included. ``` --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d9ee390b3..0a8bac0e2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -10974,8 +10974,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--sentence-transformers-dense-modules", action="store_true", - help=("Whether to include sentence-transformers dense modules." - "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + help=("Whether to include sentence-transformers dense modules. " + "It can be used for sentence-transformers models, like google/embeddinggemma-300m. " "Default these modules are not included.") ) From 56426673cb950feaff28c466c7cf38ac4c165742 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 7 Jan 2026 15:16:20 +0200 Subject: [PATCH 18/19] scripts : add pr2wt.sh (#18644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * scripts : add pr2wt.sh * script : shebang Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- .gitignore | 1 + scripts/pr2wt.sh | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100755 scripts/pr2wt.sh diff --git a/.gitignore b/.gitignore index 05eb578a8..bb122d692 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ poetry.toml # Local scripts /run-vim.sh /run-chat.sh +/run-spec.sh /.ccache/ # IDE diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh new file mode 100755 index 000000000..22251339a --- /dev/null +++ b/scripts/pr2wt.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# intialize a new worktree from a PR number: +# +# - creates a new remote using the fork's clone URL +# - creates a local branch tracking the remote branch +# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}" +# +# sample usage: +# ./scripts/pr2wt.sh 12345 +# ./scripts/pr2wt.sh 12345 opencode + +function usage() { + echo "usage: $0 [cmd]" + exit 1 +} + +# check we are in the right directory +if [[ ! -f "scripts/pr2wt.sh" ]]; then + echo "error: this script must be run from the root of the repository" + exit 1 +fi + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +PR=$1 +[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } + +url_origin=$(git config --get remote.origin.url) || { + echo "error: no remote named 'origin' in this repository" + exit 1 +} + +org_repo=$(echo $url_origin | cut -d/ -f4-) + +echo "org/repo: $org_repo" + +meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}") + +url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url') +head_ref=$(echo "$meta" | jq -r '.head.ref') + +echo "url: $url_remote" +echo "head_ref: $head_ref" + +git remote rm pr/${PR} +git remote add pr/${PR} $url_remote +git fetch pr/${PR} $head_ref + +dir=$(basename $(pwd)) + +git branch -D pr/$PR 2> /dev/null +git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null + +wt_path=$(cd ../$dir-pr-$PR && pwd) + +echo "git worktree created in $wt_path" + +# if a command was provided, execute it +if [[ $# -eq 2 ]]; then + cd ../$dir-pr-$PR + exec $2 +fi From 56d2fed2b3970ae55eebd0e5426d402304b1358a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 7 Jan 2026 16:18:26 +0100 Subject: [PATCH 19/19] tools : remove llama-run (#18661) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tools : remove llama-run * Remove licenses/LICENSE-linenoise Signed-off-by: Adrien Gallouët --- README.md | 16 - common/arg.cpp | 1 - licenses/LICENSE-linenoise | 26 - tools/CMakeLists.txt | 1 - tools/run/CMakeLists.txt | 23 - tools/run/README.md | 52 - tools/run/linenoise.cpp/linenoise.cpp | 1995 ------------------------- tools/run/linenoise.cpp/linenoise.h | 137 -- tools/run/run.cpp | 1408 ----------------- 9 files changed, 3659 deletions(-) delete mode 100644 licenses/LICENSE-linenoise delete mode 100644 tools/run/CMakeLists.txt delete mode 100644 tools/run/README.md delete mode 100644 tools/run/linenoise.cpp/linenoise.cpp delete mode 100644 tools/run/linenoise.cpp/linenoise.h delete mode 100644 tools/run/run.cpp diff --git a/README.md b/README.md index ed956bb02..e59612f7a 100644 --- a/README.md +++ b/README.md @@ -482,21 +482,6 @@ To learn more about model quantization, [read this documentation](tools/quantize -## [`llama-run`](tools/run) - -#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. - --
- Run a model with a specific prompt (by default it's pulled from Ollama registry) - - ```bash - llama-run granite-code - ``` - -
- -[^3]: [RamaLama](https://github.com/containers/ramalama) - ## [`llama-simple`](examples/simple) #### A minimal example for implementing apps with `llama.cpp`. Useful for developers. @@ -600,7 +585,6 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain - [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License - [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License -- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) - [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain - [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain diff --git a/common/arg.cpp b/common/arg.cpp index a67a26e2d..e7966d9d5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -679,7 +679,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-quantize", "llama-qwen2vl-cli", "llama-retrieval", - "llama-run", "llama-save-load-state", "llama-server", "llama-simple", diff --git a/licenses/LICENSE-linenoise b/licenses/LICENSE-linenoise deleted file mode 100644 index b006b3b24..000000000 --- a/licenses/LICENSE-linenoise +++ /dev/null @@ -1,26 +0,0 @@ -Copyright (c) 2010-2014, Salvatore Sanfilippo -Copyright (c) 2010-2013, Pieter Noordhuis -Copyright (c) 2025, Eric Curtin - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 8df3f4100..48959fefb 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -25,7 +25,6 @@ else() if (LLAMA_BUILD_SERVER) add_subdirectory(server) endif() - add_subdirectory(run) add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(mtmd) diff --git a/tools/run/CMakeLists.txt b/tools/run/CMakeLists.txt deleted file mode 100644 index 6ad7534e2..000000000 --- a/tools/run/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -set(TARGET llama-run) -add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp) - -# TODO: avoid copying this code block from common/CMakeLists.txt -set(LLAMA_RUN_EXTRA_LIBS "") -if (LLAMA_CURL) - find_package(CURL REQUIRED) - target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) - include_directories(${CURL_INCLUDE_DIRS}) - set(LLAMA_RUN_EXTRA_LIBS ${LLAMA_RUN_EXTRA_LIBS} ${CURL_LIBRARIES}) -endif () - -if(LLAMA_TOOLS_INSTALL) - install(TARGETS ${TARGET} RUNTIME) -endif() - -if (CMAKE_SYSTEM_NAME MATCHES "AIX") - # AIX's flock() function comes from libbsd.a - target_link_libraries(${TARGET} PRIVATE -lbsd) -endif() - -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/run/README.md b/tools/run/README.md deleted file mode 100644 index 5fd769b44..000000000 --- a/tools/run/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# llama.cpp/example/run - -The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. - -```bash -llama-run granite3-moe -``` - -```bash -Description: - Runs a llm - -Usage: - llama-run [options] model [prompt] - -Options: - -c, --context-size - Context size (default: 2048) - -n, -ngl, --ngl - Number of GPU layers (default: 0) - --temp - Temperature (default: 0.8) - -v, --verbose, --log-verbose - Set verbosity level to infinity (i.e. log all messages, useful for debugging) - -h, --help - Show help message - -Commands: - model - Model is a string with an optional prefix of - huggingface:// (hf://), ollama://, https:// or file://. - If no protocol is specified and a file exists in the specified - path, file:// is assumed, otherwise if a file does not exist in - the specified path, ollama:// is assumed. Models that are being - pulled are downloaded with .partial extension while being - downloaded and then renamed as the file without the .partial - extension when complete. - -Examples: - llama-run llama3 - llama-run ollama://granite-code - llama-run ollama://smollm:135m - llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run https://example.com/some-file1.gguf - llama-run some-file2.gguf - llama-run file://some-file3.gguf - llama-run --ngl 999 some-file4.gguf - llama-run --ngl 999 some-file5.gguf Hello World -``` diff --git a/tools/run/linenoise.cpp/linenoise.cpp b/tools/run/linenoise.cpp/linenoise.cpp deleted file mode 100644 index 9cb939900..000000000 --- a/tools/run/linenoise.cpp/linenoise.cpp +++ /dev/null @@ -1,1995 +0,0 @@ -#ifndef _WIN32 -/* - * You can find the latest source code at: - * - * http://github.com/ericcurtin/linenoise.cpp - * - * Does a number of crazy assumptions that happen to be true in 99.9999% of - * the 2010 UNIX computers around. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * ------------------------------------------------------------------------ - * - * References: - * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html - * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html - * - * Todo list: - * - Filter bogus Ctrl+ combinations. - * - Win32 support - * - * Bloat: - * - History search like Ctrl+r in readline? - * - * List of escape sequences used by this program, we do everything just - * with three sequences. In order to be so cheap we may have some - * flickering effect with some slow terminal, but the lesser sequences - * the more compatible. - * - * EL (Erase Line) - * Sequence: ESC [ n K - * Effect: if n is 0 or missing, clear from cursor to end of line - * Effect: if n is 1, clear from beginning of line to cursor - * Effect: if n is 2, clear entire line - * - * CUF (CUrsor Forward) - * Sequence: ESC [ n C - * Effect: moves cursor forward n chars - * - * CUB (CUrsor Backward) - * Sequence: ESC [ n D - * Effect: moves cursor backward n chars - * - * The following is used to get the terminal width if getting - * the width with the TIOCGWINSZ ioctl fails - * - * DSR (Device Status Report) - * Sequence: ESC [ 6 n - * Effect: reports the current cursor position as ESC [ n ; m R - * where n is the row and m is the column - * - * When multi line mode is enabled, we also use an additional escape - * sequence. However multi line editing is disabled by default. - * - * CUU (Cursor Up) - * Sequence: ESC [ n A - * Effect: moves cursor up of n chars. - * - * CUD (Cursor Down) - * Sequence: ESC [ n B - * Effect: moves cursor down of n chars. - * - * When linenoiseClearScreen() is called, two additional escape sequences - * are used in order to clear the screen and position the cursor at home - * position. - * - * CUP (Cursor position) - * Sequence: ESC [ H - * Effect: moves the cursor to upper left corner - * - * ED (Erase display) - * Sequence: ESC [ 2 J - * Effect: clear the whole screen - * - */ - -# include "linenoise.h" - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include -# include -# include - -# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -# define LINENOISE_MAX_LINE 4096 -static std::vector unsupported_term = { "dumb", "cons25", "emacs" }; -static linenoiseCompletionCallback *completionCallback = NULL; -static linenoiseHintsCallback *hintsCallback = NULL; -static linenoiseFreeHintsCallback *freeHintsCallback = NULL; -static char *linenoiseNoTTY(void); -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); -static void refreshLineWithFlags(struct linenoiseState *l, int flags); - -static struct termios orig_termios; /* In order to restore at exit.*/ -static int maskmode = 0; /* Show "***" instead of input. For passwords. */ -static int rawmode = 0; /* For atexit() function to check if restore is needed*/ -static int mlmode = 0; /* Multi line mode. Default is single line. */ -static int atexit_registered = 0; /* Register atexit just 1 time. */ -static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; -static int history_len = 0; -static char **history = NULL; - -enum KEY_ACTION{ - KEY_NULL = 0, /* NULL */ - CTRL_A = 1, /* Ctrl+a */ - CTRL_B = 2, /* Ctrl-b */ - CTRL_C = 3, /* Ctrl-c */ - CTRL_D = 4, /* Ctrl-d */ - CTRL_E = 5, /* Ctrl-e */ - CTRL_F = 6, /* Ctrl-f */ - CTRL_H = 8, /* Ctrl-h */ - TAB = 9, /* Tab */ - CTRL_K = 11, /* Ctrl+k */ - CTRL_L = 12, /* Ctrl+l */ - ENTER = 13, /* Enter */ - CTRL_N = 14, /* Ctrl-n */ - CTRL_P = 16, /* Ctrl-p */ - CTRL_T = 20, /* Ctrl-t */ - CTRL_U = 21, /* Ctrl+u */ - CTRL_W = 23, /* Ctrl+w */ - ESC = 27, /* Escape */ - BACKSPACE = 127 /* Backspace */ -}; - -static void linenoiseAtExit(void); -int linenoiseHistoryAdd(const char *line); -#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen -#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. -#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. -static void refreshLine(struct linenoiseState *l); - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } - } - - return 0; - } - - ~File() { - if (fd >= 0) { - flock(fd, LOCK_UN); - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -}; - -#if 0 -/* Debugging function. */ -__attribute__((format(printf, 1, 2))) -static void lndebug(const char *fmt, ...) { - static File file; - if (file.file == nullptr) { - file.open("/tmp/lndebug.txt", "a"); - } - - if (file.file != nullptr) { - va_list args; - va_start(args, fmt); - vfprintf(file.file, fmt, args); - va_end(args); - fflush(file.file); - } -} -#endif - -/* ========================== Encoding functions ============================= */ - -/* Get length of previous UTF8 codepoint */ -static size_t prevUtf8CodePointLen(const char * buf, int pos) { - int end = pos--; - while (pos >= 0 && ((unsigned char) buf[pos] & 0xC0) == 0x80) { - pos--; - } - return end - pos; -} - -/* Convert UTF8 to Unicode code point */ -static size_t utf8BytesToCodePoint(const char * buf, size_t len, int * cp) { - if (len) { - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - *cp = byte; - return 1; - } else if ((byte & 0xE0) == 0xC0) { - if (len >= 2) { - *cp = (((unsigned long) (buf[0] & 0x1F)) << 6) | ((unsigned long) (buf[1] & 0x3F)); - return 2; - } - } else if ((byte & 0xF0) == 0xE0) { - if (len >= 3) { - *cp = (((unsigned long) (buf[0] & 0x0F)) << 12) | (((unsigned long) (buf[1] & 0x3F)) << 6) | - ((unsigned long) (buf[2] & 0x3F)); - return 3; - } - } else if ((byte & 0xF8) == 0xF0) { - if (len >= 4) { - *cp = (((unsigned long) (buf[0] & 0x07)) << 18) | (((unsigned long) (buf[1] & 0x3F)) << 12) | - (((unsigned long) (buf[2] & 0x3F)) << 6) | ((unsigned long) (buf[3] & 0x3F)); - return 4; - } - } - } - return 0; -} - -/* Check if the code is a wide character */ -static const unsigned long wideCharTable[][2] = { - /* BEGIN: WIDE CHAR TABLE */ - { 0x1100, 0x115F }, - { 0x231A, 0x231B }, - { 0x2329, 0x232A }, - { 0x23E9, 0x23EC }, - { 0x23F0, 0x23F0 }, - { 0x23F3, 0x23F3 }, - { 0x25FD, 0x25FE }, - { 0x2614, 0x2615 }, - { 0x2630, 0x2637 }, - { 0x2648, 0x2653 }, - { 0x267F, 0x267F }, - { 0x268A, 0x268F }, - { 0x2693, 0x2693 }, - { 0x26A1, 0x26A1 }, - { 0x26AA, 0x26AB }, - { 0x26BD, 0x26BE }, - { 0x26C4, 0x26C5 }, - { 0x26CE, 0x26CE }, - { 0x26D4, 0x26D4 }, - { 0x26EA, 0x26EA }, - { 0x26F2, 0x26F3 }, - { 0x26F5, 0x26F5 }, - { 0x26FA, 0x26FA }, - { 0x26FD, 0x26FD }, - { 0x2705, 0x2705 }, - { 0x270A, 0x270B }, - { 0x2728, 0x2728 }, - { 0x274C, 0x274C }, - { 0x274E, 0x274E }, - { 0x2753, 0x2755 }, - { 0x2757, 0x2757 }, - { 0x2795, 0x2797 }, - { 0x27B0, 0x27B0 }, - { 0x27BF, 0x27BF }, - { 0x2B1B, 0x2B1C }, - { 0x2B50, 0x2B50 }, - { 0x2B55, 0x2B55 }, - { 0x2E80, 0x2E99 }, - { 0x2E9B, 0x2EF3 }, - { 0x2F00, 0x2FD5 }, - { 0x2FF0, 0x303E }, - { 0x3041, 0x3096 }, - { 0x3099, 0x30FF }, - { 0x3105, 0x312F }, - { 0x3131, 0x318E }, - { 0x3190, 0x31E5 }, - { 0x31EF, 0x321E }, - { 0x3220, 0x3247 }, - { 0x3250, 0xA48C }, - { 0xA490, 0xA4C6 }, - { 0xA960, 0xA97C }, - { 0xAC00, 0xD7A3 }, - { 0xF900, 0xFAFF }, - { 0xFE10, 0xFE19 }, - { 0xFE30, 0xFE52 }, - { 0xFE54, 0xFE66 }, - { 0xFE68, 0xFE6B }, - { 0xFF01, 0xFF60 }, - { 0xFFE0, 0xFFE6 }, - { 0x16FE0, 0x16FE4 }, - { 0x16FF0, 0x16FF1 }, - { 0x17000, 0x187F7 }, - { 0x18800, 0x18CD5 }, - { 0x18CFF, 0x18D08 }, - { 0x1AFF0, 0x1AFF3 }, - { 0x1AFF5, 0x1AFFB }, - { 0x1AFFD, 0x1AFFE }, - { 0x1B000, 0x1B122 }, - { 0x1B132, 0x1B132 }, - { 0x1B150, 0x1B152 }, - { 0x1B155, 0x1B155 }, - { 0x1B164, 0x1B167 }, - { 0x1B170, 0x1B2FB }, - { 0x1D300, 0x1D356 }, - { 0x1D360, 0x1D376 }, - { 0x1F004, 0x1F004 }, - { 0x1F0CF, 0x1F0CF }, - { 0x1F18E, 0x1F18E }, - { 0x1F191, 0x1F19A }, - { 0x1F200, 0x1F202 }, - { 0x1F210, 0x1F23B }, - { 0x1F240, 0x1F248 }, - { 0x1F250, 0x1F251 }, - { 0x1F260, 0x1F265 }, - { 0x1F300, 0x1F320 }, - { 0x1F32D, 0x1F335 }, - { 0x1F337, 0x1F37C }, - { 0x1F37E, 0x1F393 }, - { 0x1F3A0, 0x1F3CA }, - { 0x1F3CF, 0x1F3D3 }, - { 0x1F3E0, 0x1F3F0 }, - { 0x1F3F4, 0x1F3F4 }, - { 0x1F3F8, 0x1F43E }, - { 0x1F440, 0x1F440 }, - { 0x1F442, 0x1F4FC }, - { 0x1F4FF, 0x1F53D }, - { 0x1F54B, 0x1F54E }, - { 0x1F550, 0x1F567 }, - { 0x1F57A, 0x1F57A }, - { 0x1F595, 0x1F596 }, - { 0x1F5A4, 0x1F5A4 }, - { 0x1F5FB, 0x1F64F }, - { 0x1F680, 0x1F6C5 }, - { 0x1F6CC, 0x1F6CC }, - { 0x1F6D0, 0x1F6D2 }, - { 0x1F6D5, 0x1F6D7 }, - { 0x1F6DC, 0x1F6DF }, - { 0x1F6EB, 0x1F6EC }, - { 0x1F6F4, 0x1F6FC }, - { 0x1F7E0, 0x1F7EB }, - { 0x1F7F0, 0x1F7F0 }, - { 0x1F90C, 0x1F93A }, - { 0x1F93C, 0x1F945 }, - { 0x1F947, 0x1F9FF }, - { 0x1FA70, 0x1FA7C }, - { 0x1FA80, 0x1FA89 }, - { 0x1FA8F, 0x1FAC6 }, - { 0x1FACE, 0x1FADC }, - { 0x1FADF, 0x1FAE9 }, - { 0x1FAF0, 0x1FAF8 }, - { 0x20000, 0x2FFFD }, - { 0x30000, 0x3FFFD } - /* END: WIDE CHAR TABLE */ -}; - -static const size_t wideCharTableSize = sizeof(wideCharTable) / sizeof(wideCharTable[0]); - -static bool isWideChar(unsigned long cp) { - for (size_t i = 0; i < wideCharTableSize; i++) { - auto first_code = wideCharTable[i][0]; - auto last_code = wideCharTable[i][1]; - if (first_code > cp) { - return false; - } - if (first_code <= cp && cp <= last_code) { - return true; - } - } - return false; -} - -/* Check if the code is a combining character */ -static const unsigned long combiningCharTable[] = { - /* BEGIN: COMBINING CHAR TABLE */ - 0x0300, 0x0301, 0x0302, 0x0303, 0x0304, 0x0305, 0x0306, 0x0307, 0x0308, 0x0309, 0x030A, 0x030B, 0x030C, - 0x030D, 0x030E, 0x030F, 0x0310, 0x0311, 0x0312, 0x0313, 0x0314, 0x0315, 0x0316, 0x0317, 0x0318, 0x0319, - 0x031A, 0x031B, 0x031C, 0x031D, 0x031E, 0x031F, 0x0320, 0x0321, 0x0322, 0x0323, 0x0324, 0x0325, 0x0326, - 0x0327, 0x0328, 0x0329, 0x032A, 0x032B, 0x032C, 0x032D, 0x032E, 0x032F, 0x0330, 0x0331, 0x0332, 0x0333, - 0x0334, 0x0335, 0x0336, 0x0337, 0x0338, 0x0339, 0x033A, 0x033B, 0x033C, 0x033D, 0x033E, 0x033F, 0x0340, - 0x0341, 0x0342, 0x0343, 0x0344, 0x0345, 0x0346, 0x0347, 0x0348, 0x0349, 0x034A, 0x034B, 0x034C, 0x034D, - 0x034E, 0x034F, 0x0350, 0x0351, 0x0352, 0x0353, 0x0354, 0x0355, 0x0356, 0x0357, 0x0358, 0x0359, 0x035A, - 0x035B, 0x035C, 0x035D, 0x035E, 0x035F, 0x0360, 0x0361, 0x0362, 0x0363, 0x0364, 0x0365, 0x0366, 0x0367, - 0x0368, 0x0369, 0x036A, 0x036B, 0x036C, 0x036D, 0x036E, 0x036F, 0x0483, 0x0484, 0x0485, 0x0486, 0x0487, - 0x0591, 0x0592, 0x0593, 0x0594, 0x0595, 0x0596, 0x0597, 0x0598, 0x0599, 0x059A, 0x059B, 0x059C, 0x059D, - 0x059E, 0x059F, 0x05A0, 0x05A1, 0x05A2, 0x05A3, 0x05A4, 0x05A5, 0x05A6, 0x05A7, 0x05A8, 0x05A9, 0x05AA, - 0x05AB, 0x05AC, 0x05AD, 0x05AE, 0x05AF, 0x05B0, 0x05B1, 0x05B2, 0x05B3, 0x05B4, 0x05B5, 0x05B6, 0x05B7, - 0x05B8, 0x05B9, 0x05BA, 0x05BB, 0x05BC, 0x05BD, 0x05BF, 0x05C1, 0x05C2, 0x05C4, 0x05C5, 0x05C7, 0x0610, - 0x0611, 0x0612, 0x0613, 0x0614, 0x0615, 0x0616, 0x0617, 0x0618, 0x0619, 0x061A, 0x064B, 0x064C, 0x064D, - 0x064E, 0x064F, 0x0650, 0x0651, 0x0652, 0x0653, 0x0654, 0x0655, 0x0656, 0x0657, 0x0658, 0x0659, 0x065A, - 0x065B, 0x065C, 0x065D, 0x065E, 0x065F, 0x0670, 0x06D6, 0x06D7, 0x06D8, 0x06D9, 0x06DA, 0x06DB, 0x06DC, - 0x06DF, 0x06E0, 0x06E1, 0x06E2, 0x06E3, 0x06E4, 0x06E7, 0x06E8, 0x06EA, 0x06EB, 0x06EC, 0x06ED, 0x0711, - 0x0730, 0x0731, 0x0732, 0x0733, 0x0734, 0x0735, 0x0736, 0x0737, 0x0738, 0x0739, 0x073A, 0x073B, 0x073C, - 0x073D, 0x073E, 0x073F, 0x0740, 0x0741, 0x0742, 0x0743, 0x0744, 0x0745, 0x0746, 0x0747, 0x0748, 0x0749, - 0x074A, 0x07A6, 0x07A7, 0x07A8, 0x07A9, 0x07AA, 0x07AB, 0x07AC, 0x07AD, 0x07AE, 0x07AF, 0x07B0, 0x07EB, - 0x07EC, 0x07ED, 0x07EE, 0x07EF, 0x07F0, 0x07F1, 0x07F2, 0x07F3, 0x07FD, 0x0816, 0x0817, 0x0818, 0x0819, - 0x081B, 0x081C, 0x081D, 0x081E, 0x081F, 0x0820, 0x0821, 0x0822, 0x0823, 0x0825, 0x0826, 0x0827, 0x0829, - 0x082A, 0x082B, 0x082C, 0x082D, 0x0859, 0x085A, 0x085B, 0x0897, 0x0898, 0x0899, 0x089A, 0x089B, 0x089C, - 0x089D, 0x089E, 0x089F, 0x08CA, 0x08CB, 0x08CC, 0x08CD, 0x08CE, 0x08CF, 0x08D0, 0x08D1, 0x08D2, 0x08D3, - 0x08D4, 0x08D5, 0x08D6, 0x08D7, 0x08D8, 0x08D9, 0x08DA, 0x08DB, 0x08DC, 0x08DD, 0x08DE, 0x08DF, 0x08E0, - 0x08E1, 0x08E3, 0x08E4, 0x08E5, 0x08E6, 0x08E7, 0x08E8, 0x08E9, 0x08EA, 0x08EB, 0x08EC, 0x08ED, 0x08EE, - 0x08EF, 0x08F0, 0x08F1, 0x08F2, 0x08F3, 0x08F4, 0x08F5, 0x08F6, 0x08F7, 0x08F8, 0x08F9, 0x08FA, 0x08FB, - 0x08FC, 0x08FD, 0x08FE, 0x08FF, 0x0900, 0x0901, 0x0902, 0x093A, 0x093C, 0x0941, 0x0942, 0x0943, 0x0944, - 0x0945, 0x0946, 0x0947, 0x0948, 0x094D, 0x0951, 0x0952, 0x0953, 0x0954, 0x0955, 0x0956, 0x0957, 0x0962, - 0x0963, 0x0981, 0x09BC, 0x09C1, 0x09C2, 0x09C3, 0x09C4, 0x09CD, 0x09E2, 0x09E3, 0x09FE, 0x0A01, 0x0A02, - 0x0A3C, 0x0A41, 0x0A42, 0x0A47, 0x0A48, 0x0A4B, 0x0A4C, 0x0A4D, 0x0A51, 0x0A70, 0x0A71, 0x0A75, 0x0A81, - 0x0A82, 0x0ABC, 0x0AC1, 0x0AC2, 0x0AC3, 0x0AC4, 0x0AC5, 0x0AC7, 0x0AC8, 0x0ACD, 0x0AE2, 0x0AE3, 0x0AFA, - 0x0AFB, 0x0AFC, 0x0AFD, 0x0AFE, 0x0AFF, 0x0B01, 0x0B3C, 0x0B3F, 0x0B41, 0x0B42, 0x0B43, 0x0B44, 0x0B4D, - 0x0B55, 0x0B56, 0x0B62, 0x0B63, 0x0B82, 0x0BC0, 0x0BCD, 0x0C00, 0x0C04, 0x0C3C, 0x0C3E, 0x0C3F, 0x0C40, - 0x0C46, 0x0C47, 0x0C48, 0x0C4A, 0x0C4B, 0x0C4C, 0x0C4D, 0x0C55, 0x0C56, 0x0C62, 0x0C63, 0x0C81, 0x0CBC, - 0x0CBF, 0x0CC6, 0x0CCC, 0x0CCD, 0x0CE2, 0x0CE3, 0x0D00, 0x0D01, 0x0D3B, 0x0D3C, 0x0D41, 0x0D42, 0x0D43, - 0x0D44, 0x0D4D, 0x0D62, 0x0D63, 0x0D81, 0x0DCA, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD6, 0x0E31, 0x0E34, 0x0E35, - 0x0E36, 0x0E37, 0x0E38, 0x0E39, 0x0E3A, 0x0E47, 0x0E48, 0x0E49, 0x0E4A, 0x0E4B, 0x0E4C, 0x0E4D, 0x0E4E, - 0x0EB1, 0x0EB4, 0x0EB5, 0x0EB6, 0x0EB7, 0x0EB8, 0x0EB9, 0x0EBA, 0x0EBB, 0x0EBC, 0x0EC8, 0x0EC9, 0x0ECA, - 0x0ECB, 0x0ECC, 0x0ECD, 0x0ECE, 0x0F18, 0x0F19, 0x0F35, 0x0F37, 0x0F39, 0x0F71, 0x0F72, 0x0F73, 0x0F74, - 0x0F75, 0x0F76, 0x0F77, 0x0F78, 0x0F79, 0x0F7A, 0x0F7B, 0x0F7C, 0x0F7D, 0x0F7E, 0x0F80, 0x0F81, 0x0F82, - 0x0F83, 0x0F84, 0x0F86, 0x0F87, 0x0F8D, 0x0F8E, 0x0F8F, 0x0F90, 0x0F91, 0x0F92, 0x0F93, 0x0F94, 0x0F95, - 0x0F96, 0x0F97, 0x0F99, 0x0F9A, 0x0F9B, 0x0F9C, 0x0F9D, 0x0F9E, 0x0F9F, 0x0FA0, 0x0FA1, 0x0FA2, 0x0FA3, - 0x0FA4, 0x0FA5, 0x0FA6, 0x0FA7, 0x0FA8, 0x0FA9, 0x0FAA, 0x0FAB, 0x0FAC, 0x0FAD, 0x0FAE, 0x0FAF, 0x0FB0, - 0x0FB1, 0x0FB2, 0x0FB3, 0x0FB4, 0x0FB5, 0x0FB6, 0x0FB7, 0x0FB8, 0x0FB9, 0x0FBA, 0x0FBB, 0x0FBC, 0x0FC6, - 0x102D, 0x102E, 0x102F, 0x1030, 0x1032, 0x1033, 0x1034, 0x1035, 0x1036, 0x1037, 0x1039, 0x103A, 0x103D, - 0x103E, 0x1058, 0x1059, 0x105E, 0x105F, 0x1060, 0x1071, 0x1072, 0x1073, 0x1074, 0x1082, 0x1085, 0x1086, - 0x108D, 0x109D, 0x135D, 0x135E, 0x135F, 0x1712, 0x1713, 0x1714, 0x1732, 0x1733, 0x1752, 0x1753, 0x1772, - 0x1773, 0x17B4, 0x17B5, 0x17B7, 0x17B8, 0x17B9, 0x17BA, 0x17BB, 0x17BC, 0x17BD, 0x17C6, 0x17C9, 0x17CA, - 0x17CB, 0x17CC, 0x17CD, 0x17CE, 0x17CF, 0x17D0, 0x17D1, 0x17D2, 0x17D3, 0x17DD, 0x180B, 0x180C, 0x180D, - 0x180F, 0x1885, 0x1886, 0x18A9, 0x1920, 0x1921, 0x1922, 0x1927, 0x1928, 0x1932, 0x1939, 0x193A, 0x193B, - 0x1A17, 0x1A18, 0x1A1B, 0x1A56, 0x1A58, 0x1A59, 0x1A5A, 0x1A5B, 0x1A5C, 0x1A5D, 0x1A5E, 0x1A60, 0x1A62, - 0x1A65, 0x1A66, 0x1A67, 0x1A68, 0x1A69, 0x1A6A, 0x1A6B, 0x1A6C, 0x1A73, 0x1A74, 0x1A75, 0x1A76, 0x1A77, - 0x1A78, 0x1A79, 0x1A7A, 0x1A7B, 0x1A7C, 0x1A7F, 0x1AB0, 0x1AB1, 0x1AB2, 0x1AB3, 0x1AB4, 0x1AB5, 0x1AB6, - 0x1AB7, 0x1AB8, 0x1AB9, 0x1ABA, 0x1ABB, 0x1ABC, 0x1ABD, 0x1ABF, 0x1AC0, 0x1AC1, 0x1AC2, 0x1AC3, 0x1AC4, - 0x1AC5, 0x1AC6, 0x1AC7, 0x1AC8, 0x1AC9, 0x1ACA, 0x1ACB, 0x1ACC, 0x1ACD, 0x1ACE, 0x1B00, 0x1B01, 0x1B02, - 0x1B03, 0x1B34, 0x1B36, 0x1B37, 0x1B38, 0x1B39, 0x1B3A, 0x1B3C, 0x1B42, 0x1B6B, 0x1B6C, 0x1B6D, 0x1B6E, - 0x1B6F, 0x1B70, 0x1B71, 0x1B72, 0x1B73, 0x1B80, 0x1B81, 0x1BA2, 0x1BA3, 0x1BA4, 0x1BA5, 0x1BA8, 0x1BA9, - 0x1BAB, 0x1BAC, 0x1BAD, 0x1BE6, 0x1BE8, 0x1BE9, 0x1BED, 0x1BEF, 0x1BF0, 0x1BF1, 0x1C2C, 0x1C2D, 0x1C2E, - 0x1C2F, 0x1C30, 0x1C31, 0x1C32, 0x1C33, 0x1C36, 0x1C37, 0x1CD0, 0x1CD1, 0x1CD2, 0x1CD4, 0x1CD5, 0x1CD6, - 0x1CD7, 0x1CD8, 0x1CD9, 0x1CDA, 0x1CDB, 0x1CDC, 0x1CDD, 0x1CDE, 0x1CDF, 0x1CE0, 0x1CE2, 0x1CE3, 0x1CE4, - 0x1CE5, 0x1CE6, 0x1CE7, 0x1CE8, 0x1CED, 0x1CF4, 0x1CF8, 0x1CF9, 0x1DC0, 0x1DC1, 0x1DC2, 0x1DC3, 0x1DC4, - 0x1DC5, 0x1DC6, 0x1DC7, 0x1DC8, 0x1DC9, 0x1DCA, 0x1DCB, 0x1DCC, 0x1DCD, 0x1DCE, 0x1DCF, 0x1DD0, 0x1DD1, - 0x1DD2, 0x1DD3, 0x1DD4, 0x1DD5, 0x1DD6, 0x1DD7, 0x1DD8, 0x1DD9, 0x1DDA, 0x1DDB, 0x1DDC, 0x1DDD, 0x1DDE, - 0x1DDF, 0x1DE0, 0x1DE1, 0x1DE2, 0x1DE3, 0x1DE4, 0x1DE5, 0x1DE6, 0x1DE7, 0x1DE8, 0x1DE9, 0x1DEA, 0x1DEB, - 0x1DEC, 0x1DED, 0x1DEE, 0x1DEF, 0x1DF0, 0x1DF1, 0x1DF2, 0x1DF3, 0x1DF4, 0x1DF5, 0x1DF6, 0x1DF7, 0x1DF8, - 0x1DF9, 0x1DFA, 0x1DFB, 0x1DFC, 0x1DFD, 0x1DFE, 0x1DFF, 0x20D0, 0x20D1, 0x20D2, 0x20D3, 0x20D4, 0x20D5, - 0x20D6, 0x20D7, 0x20D8, 0x20D9, 0x20DA, 0x20DB, 0x20DC, 0x20E1, 0x20E5, 0x20E6, 0x20E7, 0x20E8, 0x20E9, - 0x20EA, 0x20EB, 0x20EC, 0x20ED, 0x20EE, 0x20EF, 0x20F0, 0x2CEF, 0x2CF0, 0x2CF1, 0x2D7F, 0x2DE0, 0x2DE1, - 0x2DE2, 0x2DE3, 0x2DE4, 0x2DE5, 0x2DE6, 0x2DE7, 0x2DE8, 0x2DE9, 0x2DEA, 0x2DEB, 0x2DEC, 0x2DED, 0x2DEE, - 0x2DEF, 0x2DF0, 0x2DF1, 0x2DF2, 0x2DF3, 0x2DF4, 0x2DF5, 0x2DF6, 0x2DF7, 0x2DF8, 0x2DF9, 0x2DFA, 0x2DFB, - 0x2DFC, 0x2DFD, 0x2DFE, 0x2DFF, 0x302A, 0x302B, 0x302C, 0x302D, 0x3099, 0x309A, 0xA66F, 0xA674, 0xA675, - 0xA676, 0xA677, 0xA678, 0xA679, 0xA67A, 0xA67B, 0xA67C, 0xA67D, 0xA69E, 0xA69F, 0xA6F0, 0xA6F1, 0xA802, - 0xA806, 0xA80B, 0xA825, 0xA826, 0xA82C, 0xA8C4, 0xA8C5, 0xA8E0, 0xA8E1, 0xA8E2, 0xA8E3, 0xA8E4, 0xA8E5, - 0xA8E6, 0xA8E7, 0xA8E8, 0xA8E9, 0xA8EA, 0xA8EB, 0xA8EC, 0xA8ED, 0xA8EE, 0xA8EF, 0xA8F0, 0xA8F1, 0xA8FF, - 0xA926, 0xA927, 0xA928, 0xA929, 0xA92A, 0xA92B, 0xA92C, 0xA92D, 0xA947, 0xA948, 0xA949, 0xA94A, 0xA94B, - 0xA94C, 0xA94D, 0xA94E, 0xA94F, 0xA950, 0xA951, 0xA980, 0xA981, 0xA982, 0xA9B3, 0xA9B6, 0xA9B7, 0xA9B8, - 0xA9B9, 0xA9BC, 0xA9BD, 0xA9E5, 0xAA29, 0xAA2A, 0xAA2B, 0xAA2C, 0xAA2D, 0xAA2E, 0xAA31, 0xAA32, 0xAA35, - 0xAA36, 0xAA43, 0xAA4C, 0xAA7C, 0xAAB0, 0xAAB2, 0xAAB3, 0xAAB4, 0xAAB7, 0xAAB8, 0xAABE, 0xAABF, 0xAAC1, - 0xAAEC, 0xAAED, 0xAAF6, 0xABE5, 0xABE8, 0xABED, 0xFB1E, 0xFE00, 0xFE01, 0xFE02, 0xFE03, 0xFE04, 0xFE05, - 0xFE06, 0xFE07, 0xFE08, 0xFE09, 0xFE0A, 0xFE0B, 0xFE0C, 0xFE0D, 0xFE0E, 0xFE0F, 0xFE20, 0xFE21, 0xFE22, - 0xFE23, 0xFE24, 0xFE25, 0xFE26, 0xFE27, 0xFE28, 0xFE29, 0xFE2A, 0xFE2B, 0xFE2C, 0xFE2D, 0xFE2E, 0xFE2F, - 0x101FD, 0x102E0, 0x10376, 0x10377, 0x10378, 0x10379, 0x1037A, 0x10A01, 0x10A02, 0x10A03, 0x10A05, 0x10A06, 0x10A0C, - 0x10A0D, 0x10A0E, 0x10A0F, 0x10A38, 0x10A39, 0x10A3A, 0x10A3F, 0x10AE5, 0x10AE6, 0x10D24, 0x10D25, 0x10D26, 0x10D27, - 0x10D69, 0x10D6A, 0x10D6B, 0x10D6C, 0x10D6D, 0x10EAB, 0x10EAC, 0x10EFC, 0x10EFD, 0x10EFE, 0x10EFF, 0x10F46, 0x10F47, - 0x10F48, 0x10F49, 0x10F4A, 0x10F4B, 0x10F4C, 0x10F4D, 0x10F4E, 0x10F4F, 0x10F50, 0x10F82, 0x10F83, 0x10F84, 0x10F85, - 0x11001, 0x11038, 0x11039, 0x1103A, 0x1103B, 0x1103C, 0x1103D, 0x1103E, 0x1103F, 0x11040, 0x11041, 0x11042, 0x11043, - 0x11044, 0x11045, 0x11046, 0x11070, 0x11073, 0x11074, 0x1107F, 0x11080, 0x11081, 0x110B3, 0x110B4, 0x110B5, 0x110B6, - 0x110B9, 0x110BA, 0x110C2, 0x11100, 0x11101, 0x11102, 0x11127, 0x11128, 0x11129, 0x1112A, 0x1112B, 0x1112D, 0x1112E, - 0x1112F, 0x11130, 0x11131, 0x11132, 0x11133, 0x11134, 0x11173, 0x11180, 0x11181, 0x111B6, 0x111B7, 0x111B8, 0x111B9, - 0x111BA, 0x111BB, 0x111BC, 0x111BD, 0x111BE, 0x111C9, 0x111CA, 0x111CB, 0x111CC, 0x111CF, 0x1122F, 0x11230, 0x11231, - 0x11234, 0x11236, 0x11237, 0x1123E, 0x11241, 0x112DF, 0x112E3, 0x112E4, 0x112E5, 0x112E6, 0x112E7, 0x112E8, 0x112E9, - 0x112EA, 0x11300, 0x11301, 0x1133B, 0x1133C, 0x11340, 0x11366, 0x11367, 0x11368, 0x11369, 0x1136A, 0x1136B, 0x1136C, - 0x11370, 0x11371, 0x11372, 0x11373, 0x11374, 0x113BB, 0x113BC, 0x113BD, 0x113BE, 0x113BF, 0x113C0, 0x113CE, 0x113D0, - 0x113D2, 0x113E1, 0x113E2, 0x11438, 0x11439, 0x1143A, 0x1143B, 0x1143C, 0x1143D, 0x1143E, 0x1143F, 0x11442, 0x11443, - 0x11444, 0x11446, 0x1145E, 0x114B3, 0x114B4, 0x114B5, 0x114B6, 0x114B7, 0x114B8, 0x114BA, 0x114BF, 0x114C0, 0x114C2, - 0x114C3, 0x115B2, 0x115B3, 0x115B4, 0x115B5, 0x115BC, 0x115BD, 0x115BF, 0x115C0, 0x115DC, 0x115DD, 0x11633, 0x11634, - 0x11635, 0x11636, 0x11637, 0x11638, 0x11639, 0x1163A, 0x1163D, 0x1163F, 0x11640, 0x116AB, 0x116AD, 0x116B0, 0x116B1, - 0x116B2, 0x116B3, 0x116B4, 0x116B5, 0x116B7, 0x1171D, 0x1171F, 0x11722, 0x11723, 0x11724, 0x11725, 0x11727, 0x11728, - 0x11729, 0x1172A, 0x1172B, 0x1182F, 0x11830, 0x11831, 0x11832, 0x11833, 0x11834, 0x11835, 0x11836, 0x11837, 0x11839, - 0x1183A, 0x1193B, 0x1193C, 0x1193E, 0x11943, 0x119D4, 0x119D5, 0x119D6, 0x119D7, 0x119DA, 0x119DB, 0x119E0, 0x11A01, - 0x11A02, 0x11A03, 0x11A04, 0x11A05, 0x11A06, 0x11A07, 0x11A08, 0x11A09, 0x11A0A, 0x11A33, 0x11A34, 0x11A35, 0x11A36, - 0x11A37, 0x11A38, 0x11A3B, 0x11A3C, 0x11A3D, 0x11A3E, 0x11A47, 0x11A51, 0x11A52, 0x11A53, 0x11A54, 0x11A55, 0x11A56, - 0x11A59, 0x11A5A, 0x11A5B, 0x11A8A, 0x11A8B, 0x11A8C, 0x11A8D, 0x11A8E, 0x11A8F, 0x11A90, 0x11A91, 0x11A92, 0x11A93, - 0x11A94, 0x11A95, 0x11A96, 0x11A98, 0x11A99, 0x11C30, 0x11C31, 0x11C32, 0x11C33, 0x11C34, 0x11C35, 0x11C36, 0x11C38, - 0x11C39, 0x11C3A, 0x11C3B, 0x11C3C, 0x11C3D, 0x11C3F, 0x11C92, 0x11C93, 0x11C94, 0x11C95, 0x11C96, 0x11C97, 0x11C98, - 0x11C99, 0x11C9A, 0x11C9B, 0x11C9C, 0x11C9D, 0x11C9E, 0x11C9F, 0x11CA0, 0x11CA1, 0x11CA2, 0x11CA3, 0x11CA4, 0x11CA5, - 0x11CA6, 0x11CA7, 0x11CAA, 0x11CAB, 0x11CAC, 0x11CAD, 0x11CAE, 0x11CAF, 0x11CB0, 0x11CB2, 0x11CB3, 0x11CB5, 0x11CB6, - 0x11D31, 0x11D32, 0x11D33, 0x11D34, 0x11D35, 0x11D36, 0x11D3A, 0x11D3C, 0x11D3D, 0x11D3F, 0x11D40, 0x11D41, 0x11D42, - 0x11D43, 0x11D44, 0x11D45, 0x11D47, 0x11D90, 0x11D91, 0x11D95, 0x11D97, 0x11EF3, 0x11EF4, 0x11F00, 0x11F01, 0x11F36, - 0x11F37, 0x11F38, 0x11F39, 0x11F3A, 0x11F40, 0x11F42, 0x11F5A, 0x13440, 0x13447, 0x13448, 0x13449, 0x1344A, 0x1344B, - 0x1344C, 0x1344D, 0x1344E, 0x1344F, 0x13450, 0x13451, 0x13452, 0x13453, 0x13454, 0x13455, 0x1611E, 0x1611F, 0x16120, - 0x16121, 0x16122, 0x16123, 0x16124, 0x16125, 0x16126, 0x16127, 0x16128, 0x16129, 0x1612D, 0x1612E, 0x1612F, 0x16AF0, - 0x16AF1, 0x16AF2, 0x16AF3, 0x16AF4, 0x16B30, 0x16B31, 0x16B32, 0x16B33, 0x16B34, 0x16B35, 0x16B36, 0x16F4F, 0x16F8F, - 0x16F90, 0x16F91, 0x16F92, 0x16FE4, 0x1BC9D, 0x1BC9E, 0x1CF00, 0x1CF01, 0x1CF02, 0x1CF03, 0x1CF04, 0x1CF05, 0x1CF06, - 0x1CF07, 0x1CF08, 0x1CF09, 0x1CF0A, 0x1CF0B, 0x1CF0C, 0x1CF0D, 0x1CF0E, 0x1CF0F, 0x1CF10, 0x1CF11, 0x1CF12, 0x1CF13, - 0x1CF14, 0x1CF15, 0x1CF16, 0x1CF17, 0x1CF18, 0x1CF19, 0x1CF1A, 0x1CF1B, 0x1CF1C, 0x1CF1D, 0x1CF1E, 0x1CF1F, 0x1CF20, - 0x1CF21, 0x1CF22, 0x1CF23, 0x1CF24, 0x1CF25, 0x1CF26, 0x1CF27, 0x1CF28, 0x1CF29, 0x1CF2A, 0x1CF2B, 0x1CF2C, 0x1CF2D, - 0x1CF30, 0x1CF31, 0x1CF32, 0x1CF33, 0x1CF34, 0x1CF35, 0x1CF36, 0x1CF37, 0x1CF38, 0x1CF39, 0x1CF3A, 0x1CF3B, 0x1CF3C, - 0x1CF3D, 0x1CF3E, 0x1CF3F, 0x1CF40, 0x1CF41, 0x1CF42, 0x1CF43, 0x1CF44, 0x1CF45, 0x1CF46, 0x1D167, 0x1D168, 0x1D169, - 0x1D17B, 0x1D17C, 0x1D17D, 0x1D17E, 0x1D17F, 0x1D180, 0x1D181, 0x1D182, 0x1D185, 0x1D186, 0x1D187, 0x1D188, 0x1D189, - 0x1D18A, 0x1D18B, 0x1D1AA, 0x1D1AB, 0x1D1AC, 0x1D1AD, 0x1D242, 0x1D243, 0x1D244, 0x1DA00, 0x1DA01, 0x1DA02, 0x1DA03, - 0x1DA04, 0x1DA05, 0x1DA06, 0x1DA07, 0x1DA08, 0x1DA09, 0x1DA0A, 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, - 0x1DA11, 0x1DA12, 0x1DA13, 0x1DA14, 0x1DA15, 0x1DA16, 0x1DA17, 0x1DA18, 0x1DA19, 0x1DA1A, 0x1DA1B, 0x1DA1C, 0x1DA1D, - 0x1DA1E, 0x1DA1F, 0x1DA20, 0x1DA21, 0x1DA22, 0x1DA23, 0x1DA24, 0x1DA25, 0x1DA26, 0x1DA27, 0x1DA28, 0x1DA29, 0x1DA2A, - 0x1DA2B, 0x1DA2C, 0x1DA2D, 0x1DA2E, 0x1DA2F, 0x1DA30, 0x1DA31, 0x1DA32, 0x1DA33, 0x1DA34, 0x1DA35, 0x1DA36, 0x1DA3B, - 0x1DA3C, 0x1DA3D, 0x1DA3E, 0x1DA3F, 0x1DA40, 0x1DA41, 0x1DA42, 0x1DA43, 0x1DA44, 0x1DA45, 0x1DA46, 0x1DA47, 0x1DA48, - 0x1DA49, 0x1DA4A, 0x1DA4B, 0x1DA4C, 0x1DA4D, 0x1DA4E, 0x1DA4F, 0x1DA50, 0x1DA51, 0x1DA52, 0x1DA53, 0x1DA54, 0x1DA55, - 0x1DA56, 0x1DA57, 0x1DA58, 0x1DA59, 0x1DA5A, 0x1DA5B, 0x1DA5C, 0x1DA5D, 0x1DA5E, 0x1DA5F, 0x1DA60, 0x1DA61, 0x1DA62, - 0x1DA63, 0x1DA64, 0x1DA65, 0x1DA66, 0x1DA67, 0x1DA68, 0x1DA69, 0x1DA6A, 0x1DA6B, 0x1DA6C, 0x1DA75, 0x1DA84, 0x1DA9B, - 0x1DA9C, 0x1DA9D, 0x1DA9E, 0x1DA9F, 0x1DAA1, 0x1DAA2, 0x1DAA3, 0x1DAA4, 0x1DAA5, 0x1DAA6, 0x1DAA7, 0x1DAA8, 0x1DAA9, - 0x1DAAA, 0x1DAAB, 0x1DAAC, 0x1DAAD, 0x1DAAE, 0x1DAAF, 0x1E000, 0x1E001, 0x1E002, 0x1E003, 0x1E004, 0x1E005, 0x1E006, - 0x1E008, 0x1E009, 0x1E00A, 0x1E00B, 0x1E00C, 0x1E00D, 0x1E00E, 0x1E00F, 0x1E010, 0x1E011, 0x1E012, 0x1E013, 0x1E014, - 0x1E015, 0x1E016, 0x1E017, 0x1E018, 0x1E01B, 0x1E01C, 0x1E01D, 0x1E01E, 0x1E01F, 0x1E020, 0x1E021, 0x1E023, 0x1E024, - 0x1E026, 0x1E027, 0x1E028, 0x1E029, 0x1E02A, 0x1E08F, 0x1E130, 0x1E131, 0x1E132, 0x1E133, 0x1E134, 0x1E135, 0x1E136, - 0x1E2AE, 0x1E2EC, 0x1E2ED, 0x1E2EE, 0x1E2EF, 0x1E4EC, 0x1E4ED, 0x1E4EE, 0x1E4EF, 0x1E5EE, 0x1E5EF, 0x1E8D0, 0x1E8D1, - 0x1E8D2, 0x1E8D3, 0x1E8D4, 0x1E8D5, 0x1E8D6, 0x1E944, 0x1E945, 0x1E946, 0x1E947, 0x1E948, 0x1E949, 0x1E94A, 0xE0100, - 0xE0101, 0xE0102, 0xE0103, 0xE0104, 0xE0105, 0xE0106, 0xE0107, 0xE0108, 0xE0109, 0xE010A, 0xE010B, 0xE010C, 0xE010D, - 0xE010E, 0xE010F, 0xE0110, 0xE0111, 0xE0112, 0xE0113, 0xE0114, 0xE0115, 0xE0116, 0xE0117, 0xE0118, 0xE0119, 0xE011A, - 0xE011B, 0xE011C, 0xE011D, 0xE011E, 0xE011F, 0xE0120, 0xE0121, 0xE0122, 0xE0123, 0xE0124, 0xE0125, 0xE0126, 0xE0127, - 0xE0128, 0xE0129, 0xE012A, 0xE012B, 0xE012C, 0xE012D, 0xE012E, 0xE012F, 0xE0130, 0xE0131, 0xE0132, 0xE0133, 0xE0134, - 0xE0135, 0xE0136, 0xE0137, 0xE0138, 0xE0139, 0xE013A, 0xE013B, 0xE013C, 0xE013D, 0xE013E, 0xE013F, 0xE0140, 0xE0141, - 0xE0142, 0xE0143, 0xE0144, 0xE0145, 0xE0146, 0xE0147, 0xE0148, 0xE0149, 0xE014A, 0xE014B, 0xE014C, 0xE014D, 0xE014E, - 0xE014F, 0xE0150, 0xE0151, 0xE0152, 0xE0153, 0xE0154, 0xE0155, 0xE0156, 0xE0157, 0xE0158, 0xE0159, 0xE015A, 0xE015B, - 0xE015C, 0xE015D, 0xE015E, 0xE015F, 0xE0160, 0xE0161, 0xE0162, 0xE0163, 0xE0164, 0xE0165, 0xE0166, 0xE0167, 0xE0168, - 0xE0169, 0xE016A, 0xE016B, 0xE016C, 0xE016D, 0xE016E, 0xE016F, 0xE0170, 0xE0171, 0xE0172, 0xE0173, 0xE0174, 0xE0175, - 0xE0176, 0xE0177, 0xE0178, 0xE0179, 0xE017A, 0xE017B, 0xE017C, 0xE017D, 0xE017E, 0xE017F, 0xE0180, 0xE0181, 0xE0182, - 0xE0183, 0xE0184, 0xE0185, 0xE0186, 0xE0187, 0xE0188, 0xE0189, 0xE018A, 0xE018B, 0xE018C, 0xE018D, 0xE018E, 0xE018F, - 0xE0190, 0xE0191, 0xE0192, 0xE0193, 0xE0194, 0xE0195, 0xE0196, 0xE0197, 0xE0198, 0xE0199, 0xE019A, 0xE019B, 0xE019C, - 0xE019D, 0xE019E, 0xE019F, 0xE01A0, 0xE01A1, 0xE01A2, 0xE01A3, 0xE01A4, 0xE01A5, 0xE01A6, 0xE01A7, 0xE01A8, 0xE01A9, - 0xE01AA, 0xE01AB, 0xE01AC, 0xE01AD, 0xE01AE, 0xE01AF, 0xE01B0, 0xE01B1, 0xE01B2, 0xE01B3, 0xE01B4, 0xE01B5, 0xE01B6, - 0xE01B7, 0xE01B8, 0xE01B9, 0xE01BA, 0xE01BB, 0xE01BC, 0xE01BD, 0xE01BE, 0xE01BF, 0xE01C0, 0xE01C1, 0xE01C2, 0xE01C3, - 0xE01C4, 0xE01C5, 0xE01C6, 0xE01C7, 0xE01C8, 0xE01C9, 0xE01CA, 0xE01CB, 0xE01CC, 0xE01CD, 0xE01CE, 0xE01CF, 0xE01D0, - 0xE01D1, 0xE01D2, 0xE01D3, 0xE01D4, 0xE01D5, 0xE01D6, 0xE01D7, 0xE01D8, 0xE01D9, 0xE01DA, 0xE01DB, 0xE01DC, 0xE01DD, - 0xE01DE, 0xE01DF, 0xE01E0, 0xE01E1, 0xE01E2, 0xE01E3, 0xE01E4, 0xE01E5, 0xE01E6, 0xE01E7, 0xE01E8, 0xE01E9, 0xE01EA, - 0xE01EB, 0xE01EC, 0xE01ED, 0xE01EE, 0xE01EF - /* END: COMBINING CHAR TABLE */ -}; - -static const unsigned long combiningCharTableSize = sizeof(combiningCharTable) / sizeof(combiningCharTable[0]); - -static bool isCombiningChar(unsigned long cp) { - for (size_t i = 0; i < combiningCharTableSize; i++) { - auto code = combiningCharTable[i]; - if (code > cp) { - return false; - } - if (code == cp) { - return true; - } - } - return false; -} - -/* Get length of previous grapheme */ -static size_t defaultPrevCharLen(const char * buf, size_t /*buf_len*/, size_t pos, size_t * col_len) { - size_t end = pos; - while (pos > 0) { - size_t len = prevUtf8CodePointLen(buf, pos); - pos -= len; - int cp; - utf8BytesToCodePoint(buf + pos, len, &cp); - if (!isCombiningChar(cp)) { - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - return end - pos; - } - } - /* NOTREACHED */ - return 0; -} - -/* Get length of next grapheme */ -static size_t defaultNextCharLen(const char * buf, size_t buf_len, size_t pos, size_t * col_len) { - size_t beg = pos; - int cp; - size_t len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (isCombiningChar(cp)) { - /* NOTREACHED */ - return 0; - } - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - pos += len; - while (pos < buf_len) { - int cp; - len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (!isCombiningChar(cp)) { - return pos - beg; - } - pos += len; - } - return pos - beg; -} - -/* Read a Unicode from file. */ -static size_t defaultReadCode(int fd, char * buf, size_t buf_len, int * cp) { - if (buf_len < 1) { - return -1; - } - size_t nread = read(fd, &buf[0], 1); - if (nread <= 0) { - return nread; - } - - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - ; - } else if ((byte & 0xE0) == 0xC0) { - if (buf_len < 2) { - return -1; - } - nread = read(fd, &buf[1], 1); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF0) == 0xE0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 2); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF8) == 0xF0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 3); - if (nread <= 0) { - return nread; - } - } else { - return -1; - } - - return utf8BytesToCodePoint(buf, buf_len, cp); -} - -/* Set default encoding functions */ -static linenoisePrevCharLen * prevCharLen = defaultPrevCharLen; -static linenoiseNextCharLen * nextCharLen = defaultNextCharLen; -static linenoiseReadCode * readCode = defaultReadCode; - -/* Set used defined encoding functions */ -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc) { - prevCharLen = prevCharLenFunc; - nextCharLen = nextCharLenFunc; - readCode = readCodeFunc; -} - -/* ======================= Low level terminal handling ====================== */ - -/* Enable "mask mode". When it is enabled, instead of the input that - * the user is typing, the terminal will just display a corresponding - * number of asterisks, like "****". This is useful for passwords and other - * secrets that should not be displayed. */ -void linenoiseMaskModeEnable(void) { - maskmode = 1; -} - -/* Disable mask mode. */ -void linenoiseMaskModeDisable(void) { - maskmode = 0; -} - -/* Set if to use or not the multi line mode. */ -void linenoiseSetMultiLine(int ml) { - mlmode = ml; -} - -/* Return true if the terminal name is in the list of terminals we know are - * not able to understand basic escape sequences. */ -static int isUnsupportedTerm(void) { - char *term = getenv("TERM"); - if (term == NULL) return 0; - for (size_t j = 0; j < unsupported_term.size(); ++j) { - if (!strcasecmp(term, unsupported_term[j])) { - return 1; - } - } - return 0; -} - -/* Raw mode: 1960 magic shit. */ -static int enableRawMode(int fd) { - struct termios raw; - - if (!isatty(STDIN_FILENO)) goto fatal; - if (!atexit_registered) { - atexit(linenoiseAtExit); - atexit_registered = 1; - } - if (tcgetattr(fd,&orig_termios) == -1) goto fatal; - - raw = orig_termios; /* modify the original mode */ - /* input modes: no break, no CR to NL, no parity check, no strip char, - * no start/stop output control. */ - raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); - /* output modes - disable post processing */ - raw.c_oflag &= ~(OPOST); - /* control modes - set 8 bit chars */ - raw.c_cflag |= (CS8); - /* local modes - choing off, canonical off, no extended functions, - * no signal chars (^Z,^C) */ - raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); - /* control chars - set return condition: min number of bytes and timer. - * We want read to return every single byte, without timeout. */ - raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ - - /* put terminal in raw mode after flushing */ - if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; - rawmode = 1; - return 0; - -fatal: - errno = ENOTTY; - return -1; -} - -static void disableRawMode(int fd) { - /* Don't even check the return value as it's too late. */ - if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) - rawmode = 0; -} - -/* Use the ESC [6n escape sequence to query the horizontal cursor position - * and return it. On error -1 is returned, on success the position of the - * cursor. */ -static int getCursorPosition(int ifd, int ofd) { - char buf[32]; - int cols, rows; - unsigned int i = 0; - - /* Report cursor location */ - if (write(ofd, "\x1b[6n", 4) != 4) return -1; - - /* Read the response: ESC [ rows ; cols R */ - while (i < sizeof(buf)-1) { - if (read(ifd,buf+i,1) != 1) break; - if (buf[i] == 'R') break; - i++; - } - buf[i] = '\0'; - - /* Parse it. */ - if (buf[0] != ESC || buf[1] != '[') return -1; - if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; - return cols; -} - -/* Try to get the number of columns in the current terminal, or assume 80 - * if it fails. */ -static int getColumns(int ifd, int ofd) { - struct winsize ws; - - if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { - /* ioctl() failed. Try to query the terminal itself. */ - int start, cols; - - /* Get the initial position so we can restore it later. */ - start = getCursorPosition(ifd,ofd); - if (start == -1) goto failed; - - /* Go to right margin and get position. */ - if (write(ofd,"\x1b[999C",6) != 6) goto failed; - cols = getCursorPosition(ifd,ofd); - if (cols == -1) goto failed; - - /* Restore position. */ - if (cols > start) { - char seq[32]; - snprintf(seq,32,"\x1b[%dD",cols-start); - if (write(ofd,seq,strlen(seq)) == -1) { - /* Can't recover... */ - } - } - return cols; - } else { - return ws.ws_col; - } - -failed: - return 80; -} - -/* Clear the screen. Used to handle ctrl+l */ -void linenoiseClearScreen(void) { - if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { - /* nothing to do, just to avoid warning. */ - } -} - -/* Beep, used for completion when there is nothing to complete or when all - * the choices were already shown. */ -static void linenoiseBeep(void) { - fprintf(stderr, "\x7"); - fflush(stderr); -} - -/* Called by completeLine() and linenoiseShow() to render the current - * edited line with the proposed completion. If the current completion table - * is already available, it is passed as second argument, otherwise the - * function will use the callback to obtain it. - * - * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { - /* Obtain the table of completions if the caller didn't provide one. */ - linenoiseCompletions ctable; - if (lc == NULL) { - completionCallback(ls->buf, &ctable); - lc = &ctable; - } - - /* Show the edited line with completion if possible, or just refresh. */ - if (ls->completion_idx < lc->len) { - struct linenoiseState saved = *ls; - ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); - ls->buf = lc->cvec[ls->completion_idx]; - refreshLineWithFlags(ls, flags); - ls->len = saved.len; - ls->pos = saved.pos; - ls->buf = saved.buf; - } else { - refreshLineWithFlags(ls, flags); - } - - if (lc == &ctable) { - ctable.to_free = false; - } -} - -enum ESC_TYPE { ESC_NULL = 0, ESC_DELETE, ESC_UP, ESC_DOWN, ESC_RIGHT, ESC_LEFT, ESC_HOME, ESC_END }; - -static ESC_TYPE readEscapeSequence(struct linenoiseState * l) { - /* Check if the file input has additional data. */ - struct pollfd pfd; - pfd.fd = l->ifd; - pfd.events = POLLIN; - - auto ret = poll(&pfd, 1, 1); // 1 millisecond timeout - if (ret <= 0) { // -1: error, 0: timeout - return ESC_NULL; - } - - /* Read the next two bytes representing the escape sequence. - * Use two calls to handle slow terminals returning the two - * chars at different times. */ - char seq[3]; - if (read(l->ifd, seq, 1) == -1) { - return ESC_NULL; - } - if (read(l->ifd, seq + 1, 1) == -1) { - return ESC_NULL; - } - - /* ESC [ sequences. */ - if (seq[0] == '[') { - if (seq[1] >= '0' && seq[1] <= '9') { - /* Extended escape, read additional byte. */ - if (read(l->ifd, seq + 2, 1) == -1) { - return ESC_NULL; - } - if (seq[2] == '~') { - switch (seq[1]) { - case '3': - return ESC_DELETE; - } - } - } else { - switch (seq[1]) { - case 'A': - return ESC_UP; - case 'B': - return ESC_DOWN; - case 'C': - return ESC_RIGHT; - case 'D': - return ESC_LEFT; - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - } - - /* ESC O sequences. */ - else if (seq[0] == 'O') { - switch (seq[1]) { - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - return ESC_NULL; -} - -/* This is an helper function for linenoiseEdit*() and is called when the - * user types the key in order to complete the string currently in the - * input. - * - * The state of the editing is encapsulated into the pointed linenoiseState - * structure as described in the structure definition. - * - * If the function returns non-zero, the caller should handle the - * returned value as a byte read from the standard input, and process - * it as usually: this basically means that the function may return a byte - * read from the terminal but not processed. Otherwise, if zero is returned, - * the input was consumed by the completeLine() function to navigate the - * possible completions, and the caller should read for the next characters - * from stdin. */ -static int completeLine(struct linenoiseState * ls, int keypressed, ESC_TYPE esc_type) { - linenoiseCompletions lc; - int nwritten; - char c = keypressed; - - completionCallback(ls->buf, &lc); - if (lc.len == 0) { - linenoiseBeep(); - ls->in_completion = 0; - } else { - if (c == TAB) { - if (ls->in_completion == 0) { - ls->in_completion = 1; - ls->completion_idx = 0; - } else { - ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1); - if (ls->completion_idx == lc.len) { - linenoiseBeep(); - } - } - c = 0; - } else if (c == ESC && esc_type == ESC_NULL) { - /* Re-show original buffer */ - if (ls->completion_idx < lc.len) { - refreshLine(ls); - } - ls->in_completion = 0; - c = 0; - } else { - /* Update buffer and return */ - if (ls->completion_idx < lc.len) { - nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]); - ls->len = ls->pos = nwritten; - } - ls->in_completion = 0; - } - - /* Show completion or original buffer */ - if (ls->in_completion && ls->completion_idx < lc.len) { - refreshLineWithCompletion(ls, &lc, REFRESH_ALL); - } else { - refreshLine(ls); - } - } - - return c; /* Return last read character */ -} - -/* Register a callback function to be called for tab-completion. */ -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { - completionCallback = fn; -} - -/* Register a hits function to be called to show hits to the user at the - * right of the prompt. */ -void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { - hintsCallback = fn; -} - -/* Register a function to free the hints returned by the hints callback - * registered with linenoiseSetHintsCallback(). */ -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { - freeHintsCallback = fn; -} - -/* This function is used by the callback function registered by the user - * in order to add completion options given the input string when the - * user typed . See the example.c source code for a very easy to - * understand example. */ -void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { - const size_t len = strlen(str); - auto copy = std::make_unique(len + 1); - if (!copy) { - return; - } - - memcpy(copy.get(), str, len + 1); - char ** cvec = static_cast(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1))); - if (cvec == nullptr) { - return; - } - - lc->cvec = cvec; - lc->cvec[lc->len++] = copy.release(); -} - -/* Get column length from begining of buffer to current byte position */ -static size_t columnPos(const char * buf, size_t buf_len, size_t pos) { - size_t ret = 0; - size_t off = 0; - while (off < pos) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - off += len; - ret += col_len; - } - return ret; -} - -/* Helper of refreshSingleLine() and refreshMultiLine() to show hints - * to the right of the prompt. */ -static void refreshShowHints(std::string & ab, struct linenoiseState * l, int pcollen) { - char seq[64]; - size_t collen = pcollen + columnPos(l->buf, l->len, l->len); - if (hintsCallback && collen < l->cols) { - int color = -1, bold = 0; - const char *hint = hintsCallback(l->buf,&color,&bold); - if (hint) { - int hintlen = strlen(hint); - int hintmaxlen = l->cols - collen; - if (hintlen > hintmaxlen) hintlen = hintmaxlen; - if (bold == 1 && color == -1) color = 37; - if (color != -1 || bold != 0) - snprintf(seq,64,"\033[%d;%d;49m",bold,color); - else - seq[0] = '\0'; - ab.append(seq); - ab.append(hint, hintlen); - if (color != -1 || bold != 0) - ab.append("\033[0m"); - - /* Call the function to free the hint returned. */ - if (freeHintsCallback) freeHintsCallback(hint); - } - } -} - -/* Check if text is an ANSI escape sequence */ -static int isAnsiEscape(const char * buf, size_t buf_len, size_t * len) { - if (buf_len > 2 && !memcmp("\033[", buf, 2)) { - size_t off = 2; - while (off < buf_len) { - switch (buf[off++]) { - case 'A': - case 'B': - case 'C': - case 'D': - case 'E': - case 'F': - case 'G': - case 'H': - case 'J': - case 'K': - case 'S': - case 'T': - case 'f': - case 'm': - *len = off; - return 1; - } - } - } - return 0; -} - -/* Get column length of prompt text */ -static size_t promptTextColumnLen(const char * prompt, size_t plen) { - char buf[LINENOISE_MAX_LINE]; - size_t buf_len = 0; - size_t off = 0; - while (off < plen) { - size_t len; - if (isAnsiEscape(prompt + off, plen - off, &len)) { - off += len; - continue; - } - buf[buf_len++] = prompt[off++]; - } - return columnPos(buf, buf_len, buf_len); -} - -/* Single line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshSingleLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int fd = l->ofd; - char *buf = l->buf; - size_t len = l->len; - size_t pos = l->pos; - std::string ab; - - while ((pcollen + columnPos(buf, len, pos)) >= l->cols) { - int chlen = nextCharLen(buf, len, 0, NULL); - buf += chlen; - len -= chlen; - pos -= chlen; - } - while (pcollen + columnPos(buf, len, len) > l->cols) { - len -= prevCharLen(buf, len, len, NULL); - } - - /* Cursor to left edge */ - snprintf(seq,sizeof(seq),"\r"); - ab.append(seq); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - while (len--) { - ab.append("*"); - } - } else { - ab.append(buf, len); - } - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - } - - /* Erase to right */ - snprintf(seq,sizeof(seq),"\x1b[0K"); - ab.append(seq); - if (flags & REFRESH_WRITE) { - /* Move cursor to original position. */ - snprintf(seq, sizeof(seq), "\r\x1b[%dC", (int) (columnPos(buf, len, pos) + pcollen)); - ab.append(seq); - } - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Get column length from begining of buffer to current byte position for multiline mode*/ -static size_t columnPosForMultiLine(const char * buf, size_t buf_len, size_t pos, size_t cols, size_t ini_pos) { - size_t ret = 0; - size_t colwid = ini_pos; - - size_t off = 0; - while (off < buf_len) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - - int dif = (int) (colwid + col_len) - (int) cols; - if (dif > 0) { - ret += dif; - colwid = col_len; - } else if (dif == 0) { - colwid = 0; - } else { - colwid += col_len; - } - - if (off >= pos) { - break; - } - off += len; - ret += col_len; - } - - return ret; -} - -/* Multi line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshMultiLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int colpos = columnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcollen); - int colpos2; /* cursor column position. */ - int rows = (pcollen + colpos + l->cols - 1) / l->cols; /* rows used by current buf. */ - int rpos = (pcollen + l->oldcolpos + l->cols) / l->cols; /* cursor relative row. */ - int rpos2; /* rpos after refresh. */ - int col; /* column position, zero-based. */ - int old_rows = l->oldrows; - int fd = l->ofd, j; - std::string ab; - l->oldrows = rows; - - /* First step: clear all the lines used before. To do so start by - * going to the last row. */ - if (flags & REFRESH_CLEAN) { - if (old_rows - rpos > 0) { - snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - ab.append(seq); - } - - /* Now for every row clear it, go up. */ - for (j = 0; j < old_rows - 1; j++) { - snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - ab.append(seq); - } - } - - if (flags & REFRESH_ALL) { - /* Clean the top line. */ - snprintf(seq,64,"\r\x1b[0K"); - ab.append(seq); - } - - /* Get column length to cursor position */ - colpos2 = columnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcollen); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - for (unsigned int i = 0; i < l->len; ++i) { - ab.append("*"); - } - } else { - ab.append(l->buf, l->len); - } - - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - - /* If we are at the very end of the screen with our prompt, we need to - * emit a newline and move the prompt to the first column. */ - if (l->pos && l->pos == l->len && (colpos2 + pcollen) % l->cols == 0) { - ab.append("\n"); - snprintf(seq,64,"\r"); - ab.append(seq); - rows++; - if (rows > (int)l->oldrows) l->oldrows = rows; - } - - /* Move cursor to right position. */ - rpos2 = (pcollen + colpos2 + l->cols) / l->cols; /* Current cursor relative row */ - - /* Go up till we reach the expected position. */ - if (rows - rpos2 > 0) { - snprintf(seq,64,"\x1b[%dA", rows-rpos2); - ab.append(seq); - } - - /* Set column. */ - col = (pcollen + colpos2) % l->cols; - if (col) - snprintf(seq,64,"\r\x1b[%dC", col); - else - snprintf(seq,64,"\r"); - ab.append(seq); - } - - l->oldcolpos = colpos2; - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Calls the two low level functions refreshSingleLine() or - * refreshMultiLine() according to the selected mode. */ -static void refreshLineWithFlags(struct linenoiseState *l, int flags) { - if (mlmode) - refreshMultiLine(l,flags); - else - refreshSingleLine(l,flags); -} - -/* Utility function to avoid specifying REFRESH_ALL all the times. */ -static void refreshLine(struct linenoiseState *l) { - refreshLineWithFlags(l,REFRESH_ALL); -} - -/* Hide the current line, when using the multiplexing API. */ -void linenoiseHide(struct linenoiseState *l) { - if (mlmode) - refreshMultiLine(l,REFRESH_CLEAN); - else - refreshSingleLine(l,REFRESH_CLEAN); -} - -/* Show the current line, when using the multiplexing API. */ -void linenoiseShow(struct linenoiseState *l) { - if (l->in_completion) { - refreshLineWithCompletion(l,NULL,REFRESH_WRITE); - } else { - refreshLineWithFlags(l,REFRESH_WRITE); - } -} - -/* Insert the character 'c' at cursor current position. - * - * On error writing to the terminal -1 is returned, otherwise 0. */ -static int linenoiseEditInsert(struct linenoiseState * l, const char * cbuf, int clen) { - if (l->len + clen <= l->buflen) { - if (l->len == l->pos) { - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - ; - l->buf[l->len] = '\0'; - if ((!mlmode && promptTextColumnLen(l->prompt, l->plen) + columnPos(l->buf, l->len, l->len) < l->cols && - !hintsCallback)) { - /* Avoid a full update of the line in the - * trivial case. */ - if (maskmode == 1) { - static const char d = '*'; - if (write(l->ofd, &d, 1) == -1) { - return -1; - } - } else { - if (write(l->ofd, cbuf, clen) == -1) { - return -1; - } - } - } else { - refreshLine(l); - } - } else { - memmove(l->buf + l->pos + clen, l->buf + l->pos, l->len - l->pos); - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - l->buf[l->len] = '\0'; - refreshLine(l); - } - } - return 0; -} - -/* Move cursor on the left. */ -static void linenoiseEditMoveLeft(struct linenoiseState * l) { - if (l->pos > 0) { - l->pos -= prevCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor on the right. */ -static void linenoiseEditMoveRight(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos += nextCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor to the start of the line. */ -static void linenoiseEditMoveHome(struct linenoiseState * l) { - if (l->pos != 0) { - l->pos = 0; - refreshLine(l); - } -} - -/* Move cursor to the end of the line. */ -static void linenoiseEditMoveEnd(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos = l->len; - refreshLine(l); - } -} - -/* Substitute the currently edited line with the next or previous history - * entry as specified by 'dir'. */ -#define LINENOISE_HISTORY_NEXT 0 -#define LINENOISE_HISTORY_PREV 1 - -static void linenoiseEditHistoryNext(struct linenoiseState * l, int dir) { - if (history_len > 1) { - /* Update the current history entry before to - * overwrite it with the next one. */ - free(history[history_len - 1 - l->history_index]); - history[history_len - 1 - l->history_index] = strdup(l->buf); - /* Show the new entry */ - l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; - if (l->history_index < 0) { - l->history_index = 0; - return; - } else if (l->history_index >= history_len) { - l->history_index = history_len-1; - return; - } - strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); - l->buf[l->buflen-1] = '\0'; - l->len = l->pos = strlen(l->buf); - refreshLine(l); - } -} - -/* Delete the character at the right of the cursor without altering the cursor - * position. Basically this is what happens with the "Delete" keyboard key. */ -static void linenoiseEditDelete(struct linenoiseState * l) { - if (l->len > 0 && l->pos < l->len) { - int chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos, l->buf + l->pos + chlen, l->len - l->pos - chlen); - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Backspace implementation. */ -static void linenoiseEditBackspace(struct linenoiseState * l) { - if (l->pos > 0 && l->len > 0) { - int chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos - chlen, l->buf + l->pos, l->len - l->pos); - l->pos -= chlen; - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Delete the previous word, maintaining the cursor at the start of the - * current word. */ -static void linenoiseEditDeletePrevWord(struct linenoiseState * l) { - size_t old_pos = l->pos; - size_t diff; - - while (l->pos > 0 && l->buf[l->pos-1] == ' ') - l->pos--; - while (l->pos > 0 && l->buf[l->pos-1] != ' ') - l->pos--; - diff = old_pos - l->pos; - memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); - l->len -= diff; - refreshLine(l); -} - -/* This function is part of the multiplexed API of Linenoise, that is used - * in order to implement the blocking variant of the API but can also be - * called by the user directly in an event driven program. It will: - * - * 1. Initialize the linenoise state passed by the user. - * 2. Put the terminal in RAW mode. - * 3. Show the prompt. - * 4. Return control to the user, that will have to call linenoiseEditFeed() - * each time there is some data arriving in the standard input. - * - * The user can also call linenoiseEditHide() and linenoiseEditShow() if it - * is required to show some input arriving asynchronously, without mixing - * it with the currently edited line. - * - * When linenoiseEditFeed() returns non-NULL, the user finished with the - * line editing session (pressed enter CTRL-D/C): in this case the caller - * needs to call linenoiseEditStop() to put back the terminal in normal - * mode. This will not destroy the buffer, as long as the linenoiseState - * is still valid in the context of the caller. - * - * The function returns 0 on success, or -1 if writing to standard output - * fails. If stdin_fd or stdout_fd are set to -1, the default is to use - * STDIN_FILENO and STDOUT_FILENO. - */ -int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { - /* Populate the linenoise state that we pass to functions implementing - * specific editing functionalities. */ - l->in_completion = 0; - l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; - l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; - l->buf = buf; - l->buflen = buflen; - l->prompt = prompt; - l->plen = strlen(prompt); - l->oldcolpos = l->pos = 0; - l->len = 0; - - /* Enter raw mode. */ - if (enableRawMode(l->ifd) == -1) return -1; - - l->cols = getColumns(stdin_fd, stdout_fd); - l->oldrows = 0; - l->history_index = 0; - - /* Buffer starts empty. */ - l->buf[0] = '\0'; - l->buflen--; /* Make sure there is always space for the nullterm */ - - /* If stdin is not a tty, stop here with the initialization. We - * will actually just read a line from standard input in blocking - * mode later, in linenoiseEditFeed(). */ - if (!isatty(l->ifd)) return 0; - - /* The latest history entry is always our current buffer, that - * initially is just an empty string. */ - linenoiseHistoryAdd(""); - - if (write(l->ofd,prompt,l->plen) == -1) return -1; - return 0; -} - -const char* linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; - -static const char * handleEnterKey(struct linenoiseState * l) { - --history_len; - free(history[history_len]); - if (mlmode) { - linenoiseEditMoveEnd(l); - } - if (hintsCallback) { - /* Force a refresh without hints to leave the previous - * line as the user typed it after a newline. */ - linenoiseHintsCallback * hc = hintsCallback; - hintsCallback = NULL; - refreshLine(l); - hintsCallback = hc; - } - - return strdup(l->buf); -} - -static const char * handleCtrlCKey() { - errno = EAGAIN; - return NULL; -} - -static const char * handleCtrlDKey(struct linenoiseState * l) { - if (l->len > 0) { - linenoiseEditDelete(l); - return linenoiseEditMore; - } - - --history_len; - free(history[history_len]); - errno = ENOENT; - return NULL; -} - -static void handleCtrlTKey(struct linenoiseState * l) { - if (l->pos > 0 && l->pos < l->len) { - auto prev_chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - auto curr_chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - - std::string prev_char(prev_chlen, 0); - memcpy(prev_char.data(), l->buf + l->pos - prev_chlen, prev_chlen); - memmove(l->buf + l->pos - prev_chlen, l->buf + l->pos, curr_chlen); - memmove(l->buf + l->pos - prev_chlen + curr_chlen, prev_char.data(), prev_chlen); - - l->pos = l->pos - prev_chlen + curr_chlen; - if (l->pos + prev_chlen != l->len) { - l->pos += prev_chlen; - } - - refreshLine(l); - } -} - -static void handleEscapeSequence(struct linenoiseState * l, int esc_type) { - switch (esc_type) { - case ESC_NULL: - break; - case ESC_DELETE: - linenoiseEditDelete(l); - break; - case ESC_UP: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case ESC_DOWN: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC_RIGHT: - linenoiseEditMoveRight(l); - break; - case ESC_LEFT: - linenoiseEditMoveLeft(l); - break; - case ESC_HOME: - linenoiseEditMoveHome(l); - break; - case ESC_END: - linenoiseEditMoveEnd(l); - break; - } -} - -static void handleCtrlUKey(struct linenoiseState * l) { - l->buf[0] = '\0'; - l->pos = l->len = 0; - refreshLine(l); -} - -static void handleCtrlKKey(struct linenoiseState * l) { - l->buf[l->pos] = '\0'; - l->len = l->pos; - refreshLine(l); -} - -static const char * processInputCharacter(struct linenoiseState * l, int c, char * cbuf, int nread, int esc_type) { - switch (c) { - case ENTER: - return handleEnterKey(l); - case CTRL_C: - return handleCtrlCKey(); - case BACKSPACE: - case CTRL_H: - linenoiseEditBackspace(l); - break; - case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the - line is empty, act as end-of-file. */ - return handleCtrlDKey(l); - case CTRL_T: - handleCtrlTKey(l); - break; - case CTRL_B: - linenoiseEditMoveLeft(l); - break; - case CTRL_F: - linenoiseEditMoveRight(l); - break; - case CTRL_P: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case CTRL_N: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC: - handleEscapeSequence(l, esc_type); - break; - default: - if (linenoiseEditInsert(l, cbuf, nread)) { - return NULL; - } - break; - case CTRL_U: /* Ctrl+u, delete the whole line. */ - handleCtrlUKey(l); - break; - case CTRL_K: /* Ctrl+k, delete from current to end of line. */ - handleCtrlKKey(l); - break; - case CTRL_A: /* Ctrl+a, go to the start of the line */ - linenoiseEditMoveHome(l); - break; - case CTRL_E: /* ctrl+e, go to the end of the line */ - linenoiseEditMoveEnd(l); - break; - case CTRL_L: /* ctrl+l, clear screen */ - linenoiseClearScreen(); - refreshLine(l); - break; - case CTRL_W: /* ctrl+w, delete previous word */ - linenoiseEditDeletePrevWord(l); - break; - } - return linenoiseEditMore; -} - -/* This function is part of the multiplexed API of linenoise, see the top - * comment on linenoiseEditStart() for more information. Call this function - * each time there is some data to read from the standard input file - * descriptor. In the case of blocking operations, this function can just be - * called in a loop, and block. - * - * The function returns linenoiseEditMore to signal that line editing is still - * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise - * the function returns the pointer to the heap-allocated buffer with the - * edited line, that the user should free with linenoiseFree(). - * - * On special conditions, NULL is returned and errno is populated: - * - * EAGAIN if the user pressed Ctrl-C - * ENOENT if the user pressed Ctrl-D - * - * Some other errno: I/O error. - */ -const char * linenoiseEditFeed(struct linenoiseState * l) { - /* Not a TTY, pass control to line reading without character count - * limits. */ - if (!isatty(l->ifd)) return linenoiseNoTTY(); - - int c; - int nread; - char cbuf[32]; - - nread = readCode(l->ifd, cbuf, sizeof(cbuf), &c); - if (nread <= 0) return NULL; - - auto esc_type = ESC_NULL; - if (c == ESC) { - esc_type = readEscapeSequence(l); - } - - /* Only autocomplete when the callback is set. It returns < 0 when - * there was an error reading from fd. Otherwise it will return the - * character that should be handled next. */ - if ((l->in_completion || c == 9) && completionCallback != NULL) { - c = completeLine(l, c, esc_type); - /* Read next character when 0 */ - if (c == 0) return linenoiseEditMore; - } - - return processInputCharacter(l, c, cbuf, nread, esc_type); -} - -/* This is part of the multiplexed linenoise API. See linenoiseEditStart() - * for more information. This function is called when linenoiseEditFeed() - * returns something different than NULL. At this point the user input - * is in the buffer, and we can restore the terminal in normal mode. */ -void linenoiseEditStop(struct linenoiseState *l) { - if (!isatty(l->ifd)) return; - disableRawMode(l->ifd); - printf("\n"); -} - -/* This just implements a blocking loop for the multiplexed API. - * In many applications that are not event-driven, we can just call - * the blocking linenoise API, wait for the user to complete the editing - * and return the buffer. */ -static const char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) -{ - struct linenoiseState l; - - /* Editing without a buffer is invalid. */ - if (buflen == 0) { - errno = EINVAL; - return NULL; - } - - linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); - const char *res; - while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); - linenoiseEditStop(&l); - return res; -} - -/* This special mode is used by linenoise in order to print scan codes - * on screen for debugging / development purposes. It is implemented - * by the linenoise_example program using the --keycodes option. */ -void linenoisePrintKeyCodes(void) { - char quit[4]; - - printf("Linenoise key codes debugging mode.\n" - "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); - if (enableRawMode(STDIN_FILENO) == -1) return; - memset(quit,' ',4); - while(1) { - char c; - int nread; - - nread = read(STDIN_FILENO,&c,1); - if (nread <= 0) continue; - memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ - quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ - if (memcmp(quit,"quit",sizeof(quit)) == 0) break; - - printf("'%c' %02x (%d) (type quit to exit)\n", isprint((int) c) ? c : '?', (int) c, (int) c); - printf("\r"); /* Go left edge manually, we are in raw mode. */ - fflush(stdout); - } - disableRawMode(STDIN_FILENO); -} - -/* This function is called when linenoise() is called with the standard - * input file descriptor not attached to a TTY. So for example when the - * program using linenoise is called in pipe or with a file redirected - * to its standard input. In this case, we want to be able to return the - * line regardless of its length (by default we are limited to 4k). */ -static char *linenoiseNoTTY(void) { - char *line = NULL; - size_t len = 0, maxlen = 0; - - while(1) { - if (len == maxlen) { - if (maxlen == 0) maxlen = 16; - maxlen *= 2; - char *oldval = line; - line = (char*) realloc(line,maxlen); - if (line == NULL) { - if (oldval) free(oldval); - return NULL; - } - } - int c = fgetc(stdin); - if (c == EOF || c == '\n') { - if (c == EOF && len == 0) { - free(line); - return NULL; - } else { - line[len] = '\0'; - return line; - } - } else { - line[len] = c; - len++; - } - } -} - -/* The high level function that is the main API of the linenoise library. - * This function checks if the terminal has basic capabilities, just checking - * for a blacklist of stupid terminals, and later either calls the line - * editing function or uses dummy fgets() so that you will be able to type - * something even in the most desperate of the conditions. */ -const char *linenoise(const char *prompt) { - char buf[LINENOISE_MAX_LINE]; - - if (!isatty(STDIN_FILENO)) { - /* Not a tty: read from file / pipe. In this mode we don't want any - * limit to the line size, so we call a function to handle that. */ - return linenoiseNoTTY(); - } else if (isUnsupportedTerm()) { - size_t len; - - printf("%s",prompt); - fflush(stdout); - if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; - len = strlen(buf); - while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { - len--; - buf[len] = '\0'; - } - return strdup(buf); - } else { - const char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); - return retval; - } -} - -/* This is just a wrapper the user may want to call in order to make sure - * the linenoise returned buffer is freed with the same allocator it was - * created with. Useful when the main program is using an alternative - * allocator. */ -void linenoiseFree(void *ptr) { - if (ptr == linenoiseEditMore) return; // Protect from API misuse. - free(ptr); -} - -/* ================================ History ================================= */ - -/* Free the history, but does not reset it. Only used when we have to - * exit() to avoid memory leaks are reported by valgrind & co. */ -static void freeHistory(void) { - if (history) { - int j; - - for (j = 0; j < history_len; j++) - free(history[j]); - free(history); - } -} - -/* At exit we'll try to fix the terminal to the initial conditions. */ -static void linenoiseAtExit(void) { - disableRawMode(STDIN_FILENO); - freeHistory(); -} - -/* This is the API call to add a new entry in the linenoise history. - * It uses a fixed array of char pointers that are shifted (memmoved) - * when the history max length is reached in order to remove the older - * entry and make room for the new one, so it is not exactly suitable for huge - * histories, but will work well for a few hundred of entries. - * - * Using a circular buffer is smarter, but a bit more complex to handle. */ -int linenoiseHistoryAdd(const char *line) { - char *linecopy; - - if (history_max_len == 0) return 0; - - /* Initialization on first call. */ - if (history == NULL) { - history = (char**) malloc(sizeof(char*)*history_max_len); - if (history == NULL) return 0; - memset(history,0,(sizeof(char*)*history_max_len)); - } - - /* Don't add duplicated lines. */ - if (history_len && !strcmp(history[history_len-1], line)) return 0; - - /* Add an heap allocated copy of the line in the history. - * If we reached the max length, remove the older line. */ - linecopy = strdup(line); - if (!linecopy) return 0; - if (history_len == history_max_len) { - free(history[0]); - memmove(history,history+1,sizeof(char*)*(history_max_len-1)); - history_len--; - } - history[history_len] = linecopy; - history_len++; - return 1; -} - -/* Set the maximum length for the history. This function can be called even - * if there is already some history, the function will make sure to retain - * just the latest 'len' elements if the new history length value is smaller - * than the amount of items already inside the history. */ -int linenoiseHistorySetMaxLen(int len) { - char **new_ptr; - - if (len < 1) return 0; - if (history) { - int tocopy = history_len; - - new_ptr = (char**) malloc(sizeof(char*)*len); - if (new_ptr == NULL) return 0; - - /* If we can't copy everything, free the elements we'll not use. */ - if (len < tocopy) { - int j; - - for (j = 0; j < tocopy-len; j++) free(history[j]); - tocopy = len; - } - memset(new_ptr,0,sizeof(char*)*len); - memcpy(new_ptr,history+(history_len-tocopy), sizeof(char*)*tocopy); - free(history); - history = new_ptr; - } - history_max_len = len; - if (history_len > history_max_len) - history_len = history_max_len; - return 1; -} - -/* Save the history in the specified file. On success 0 is returned - * otherwise -1 is returned. */ -int linenoiseHistorySave(const char *filename) { - mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); - File file; - file.open(filename, "w"); - umask(old_umask); - if (file.file == NULL) { - return -1; - } - chmod(filename,S_IRUSR|S_IWUSR); - for (int j = 0; j < history_len; ++j) { - fprintf(file.file, "%s\n", history[j]); - } - - return 0; -} - -/* Load the history from the specified file. If the file does not exist - * zero is returned and no operation is performed. - * - * If the file exists and the operation succeeded 0 is returned, otherwise - * on error -1 is returned. */ -int linenoiseHistoryLoad(const char *filename) { - File file; - file.open(filename, "r"); - char buf[LINENOISE_MAX_LINE]; - if (file.file == NULL) { - return -1; - } - - while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) { - char *p; - - p = strchr(buf,'\r'); - if (!p) p = strchr(buf,'\n'); - if (p) *p = '\0'; - linenoiseHistoryAdd(buf); - } - return 0; -} -#endif diff --git a/tools/run/linenoise.cpp/linenoise.h b/tools/run/linenoise.cpp/linenoise.h deleted file mode 100644 index 9823ca36d..000000000 --- a/tools/run/linenoise.cpp/linenoise.h +++ /dev/null @@ -1,137 +0,0 @@ -/* linenoise.h -- VERSION 1.0 - * - * Guerrilla line editing library against the idea that a line editing lib - * needs to be 20,000 lines of C++ code. - * - * See linenoise.cpp for more information. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef __LINENOISE_H -#define __LINENOISE_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include /* For size_t. */ -#include - -extern const char * linenoiseEditMore; - -/* The linenoiseState structure represents the state during line editing. - * We pass this state to functions implementing specific editing - * functionalities. */ -struct linenoiseState { - int in_completion; /* The user pressed TAB and we are now in completion - * mode, so input is handled by completeLine(). */ - size_t completion_idx; /* Index of next completion to propose. */ - int ifd; /* Terminal stdin file descriptor. */ - int ofd; /* Terminal stdout file descriptor. */ - char * buf; /* Edited line buffer. */ - size_t buflen; /* Edited line buffer size. */ - const char * prompt; /* Prompt to display. */ - size_t plen; /* Prompt length. */ - size_t pos; /* Current cursor position. */ - size_t oldcolpos; /* Previous refresh cursor column position. */ - size_t len; /* Current edited line length. */ - size_t cols; /* Number of columns in terminal. */ - size_t oldrows; /* Rows used by last refreshed line (multiline mode) */ - int history_index; /* The history index we are currently editing. */ -}; - -struct linenoiseCompletions { - size_t len = 0; - char ** cvec = nullptr; - bool to_free = true; - - ~linenoiseCompletions() { - if (!to_free) { - return; - } - - for (size_t i = 0; i < len; ++i) { - free(cvec[i]); - } - - free(cvec); - } -}; - -/* Non blocking API. */ -int linenoiseEditStart(struct linenoiseState * l, int stdin_fd, int stdout_fd, char * buf, size_t buflen, - const char * prompt); -const char * linenoiseEditFeed(struct linenoiseState * l); -void linenoiseEditStop(struct linenoiseState * l); -void linenoiseHide(struct linenoiseState * l); -void linenoiseShow(struct linenoiseState * l); - -/* Blocking API. */ -const char * linenoise(const char * prompt); -void linenoiseFree(void * ptr); - -/* Completion API. */ -typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); -typedef const char *(linenoiseHintsCallback) (const char *, int * color, int * bold); -typedef void(linenoiseFreeHintsCallback)(const char *); -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); -void linenoiseSetHintsCallback(linenoiseHintsCallback *); -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); -void linenoiseAddCompletion(linenoiseCompletions *, const char *); - -/* History API. */ -int linenoiseHistoryAdd(const char * line); -int linenoiseHistorySetMaxLen(int len); -int linenoiseHistorySave(const char * filename); -int linenoiseHistoryLoad(const char * filename); - -/* Other utilities. */ -void linenoiseClearScreen(void); -void linenoiseSetMultiLine(int ml); -void linenoisePrintKeyCodes(void); -void linenoiseMaskModeEnable(void); -void linenoiseMaskModeDisable(void); - -/* Encoding functions. */ -typedef size_t(linenoisePrevCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseNextCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseReadCode)(int fd, char * buf, size_t buf_len, int * c); - -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc); - -#ifdef __cplusplus -} -#endif - -#endif /* __LINENOISE_H */ diff --git a/tools/run/run.cpp b/tools/run/run.cpp deleted file mode 100644 index b90a7253c..000000000 --- a/tools/run/run.cpp +++ /dev/null @@ -1,1408 +0,0 @@ -#include "chat.h" -#include "common.h" -#include "llama-cpp.h" -#include "log.h" - -#include "linenoise.cpp/linenoise.h" - -#define JSON_ASSERT GGML_ASSERT -#include - -#if defined(_WIN32) -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -#endif - -#if defined(LLAMA_USE_CURL) -# include -#else -# include "http.h" -#endif - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) -[[noreturn]] static void sigint_handler(int) { - printf("\n" LOG_COL_DEFAULT); - exit(0); // not ideal, but it's the only way to guarantee exit in all cases -} -#endif - -GGML_ATTRIBUTE_FORMAT(1, 2) -static int printe(const char * fmt, ...) { - va_list args; - va_start(args, fmt); - const int ret = vfprintf(stderr, fmt, args); - va_end(args); - - return ret; -} - -static std::string strftime_fmt(const char * fmt, const std::tm & tm) { - std::ostringstream oss; - oss << std::put_time(&tm, fmt); - - return oss.str(); -} - -class Opt { - public: - int init(int argc, const char ** argv) { - ctx_params = llama_context_default_params(); - model_params = llama_model_default_params(); - context_size_default = ctx_params.n_batch; - n_threads_default = ctx_params.n_threads; - ngl_default = model_params.n_gpu_layers; - common_params_sampling sampling; - temperature_default = sampling.temp; - - if (argc < 2) { - printe("Error: No arguments provided.\n"); - print_help(); - return 1; - } - - // Parse arguments - if (parse(argc, argv)) { - printe("Error: Failed to parse arguments.\n"); - print_help(); - return 1; - } - - // If help is requested, show help and exit - if (help) { - print_help(); - return 2; - } - - ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; - ctx_params.n_ctx = ctx_params.n_batch; - ctx_params.n_threads = ctx_params.n_threads_batch = n_threads >= 0 ? n_threads : n_threads_default; - model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; - temperature = temperature >= 0 ? temperature : temperature_default; - - return 0; // Success - } - - llama_context_params ctx_params; - llama_model_params model_params; - std::string model_; - std::string chat_template_file; - std::string user; - bool use_jinja = false; - int context_size = -1, ngl = -1, n_threads = -1; - float temperature = -1; - bool verbose = false; - - private: - int context_size_default = -1, ngl_default = -1, n_threads_default = -1; - float temperature_default = -1; - bool help = false; - - bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) { - return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atoi(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atof(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = argv[++i]; - - return 0; - } - - int parse_options_with_value(int argc, const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) { - if (handle_option_with_value(argc, argv, i, context_size) == 1) { - return 1; - } - } else if (options_parsing && - (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) { - if (handle_option_with_value(argc, argv, i, ngl) == 1) { - return 1; - } - } else if (options_parsing && (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--threads") == 0)) { - if (handle_option_with_value(argc, argv, i, n_threads) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--temp") == 0) { - if (handle_option_with_value(argc, argv, i, temperature) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) { - if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) { - return 1; - } - use_jinja = true; - } else { - return 2; - } - - return 0; - } - - int parse_options(const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { - verbose = true; - } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { - use_jinja = true; - } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { - help = true; - return 0; - } else if (options_parsing && strcmp(argv[i], "--") == 0) { - options_parsing = false; - } else { - return 2; - } - - return 0; - } - - int parse_positional_args(const char ** argv, int & i, int & positional_args_i) { - if (positional_args_i == 0) { - if (!argv[i][0] || argv[i][0] == '-') { - return 1; - } - - ++positional_args_i; - model_ = argv[i]; - } else if (positional_args_i == 1) { - ++positional_args_i; - user = argv[i]; - } else { - user += " " + std::string(argv[i]); - } - - return 0; - } - - int parse(int argc, const char ** argv) { - bool options_parsing = true; - for (int i = 1, positional_args_i = 0; i < argc; ++i) { - int ret = parse_options_with_value(argc, argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - ret = parse_options(argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - if (parse_positional_args(argv, i, positional_args_i)) { - return 1; - } - } - - if (model_.empty()) { - return 1; - } - - return 0; - } - - void print_help() const { - printf( - "Description:\n" - " Runs a llm\n" - "\n" - "Usage:\n" - " llama-run [options] model [prompt]\n" - "\n" - "Options:\n" - " -c, --context-size \n" - " Context size (default: %d)\n" - " --chat-template-file \n" - " Path to the file containing the chat template to use with the model.\n" - " Only supports jinja templates and implicitly sets the --jinja flag.\n" - " --jinja\n" - " Use jinja templating for the chat template of the model\n" - " -n, -ngl, --ngl \n" - " Number of GPU layers (default: %d)\n" - " --temp \n" - " Temperature (default: %.1f)\n" - " -t, --threads \n" - " Number of threads to use during generation (default: %d)\n" - " -v, --verbose, --log-verbose\n" - " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n" - " -h, --help\n" - " Show help message\n" - "\n" - "Commands:\n" - " model\n" - " Model is a string with an optional prefix of \n" - " huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n" - " If no protocol is specified and a file exists in the specified\n" - " path, file:// is assumed, otherwise if a file does not exist in\n" - " the specified path, ollama:// is assumed. Models that are being\n" - " pulled are downloaded with .partial extension while being\n" - " downloaded and then renamed as the file without the .partial\n" - " extension when complete.\n" - "\n" - "Examples:\n" - " llama-run llama3\n" - " llama-run ollama://granite-code\n" - " llama-run ollama://smollm:135m\n" - " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run https://example.com/some-file1.gguf\n" - " llama-run some-file2.gguf\n" - " llama-run file://some-file3.gguf\n" - " llama-run --ngl 999 some-file4.gguf\n" - " llama-run --ngl 999 some-file5.gguf Hello World\n", - context_size_default, ngl_default, temperature_default, n_threads_default); - } -}; - -struct progress_data { - size_t file_size = 0; - std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - bool printed = false; -}; - -static int get_terminal_width() { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO csbi; - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - return csbi.srWindow.Right - csbi.srWindow.Left + 1; -#else - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - return w.ws_col; -#endif -} - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = ggml_fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { -# ifdef _WIN32 - fd = _fileno(file); - hFile = (HANDLE) _get_osfhandle(fd); - if (hFile == INVALID_HANDLE_VALUE) { - fd = -1; - - return 1; - } - - OVERLAPPED overlapped = {}; - if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, - &overlapped)) { - fd = -1; - - return 1; - } -# else - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } -# endif - } - - return 0; - } - - std::string to_string() { - fseek(file, 0, SEEK_END); - const size_t size = ftell(file); - fseek(file, 0, SEEK_SET); - std::string out; - out.resize(size); - const size_t read_size = fread(&out[0], 1, size, file); - if (read_size != size) { - printe("Error reading file: %s", strerror(errno)); - } - - return out; - } - - ~File() { - if (fd >= 0) { -# ifdef _WIN32 - if (hFile != INVALID_HANDLE_VALUE) { - OVERLAPPED overlapped = {}; - UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped); - } -# else - flock(fd, LOCK_UN); -# endif - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -# ifdef _WIN32 - HANDLE hFile = nullptr; -# endif -}; - -class HttpClient { - public: - int init(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - if (std::filesystem::exists(output_file)) { - return 0; - } - - std::string output_file_partial; - - if (!output_file.empty()) { - output_file_partial = output_file + ".partial"; - } - - if (download(url, headers, output_file_partial, progress, response_str)) { - return 1; - } - - if (!output_file.empty()) { - try { - std::filesystem::rename(output_file_partial, output_file); - } catch (const std::filesystem::filesystem_error & e) { - printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what()); - return 1; - } - } - - return 0; - } - -#ifdef LLAMA_USE_CURL - - ~HttpClient() { - if (chunk) { - curl_slist_free_all(chunk); - } - - if (curl) { - curl_easy_cleanup(curl); - } - } - - private: - CURL * curl = nullptr; - struct curl_slist * chunk = nullptr; - - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - curl = curl_easy_init(); - if (!curl) { - return 1; - } - - progress_data data; - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - - return 1; - } - - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - - return 1; - } - } - - set_write_options(response_str, out); - data.file_size = set_resume_point(output_file); - set_progress_options(progress, data); - set_headers(headers); - CURLcode res = perform(url); - if (res != CURLE_OK){ - printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); - return 1; - } - - return 0; - } - - void set_write_options(std::string * response_str, const File & out) { - if (response_str) { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); - } else { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file); - } - } - - size_t set_resume_point(const std::string & output_file) { - size_t file_size = 0; - if (std::filesystem::exists(output_file)) { - file_size = std::filesystem::file_size(output_file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); - } - - return file_size; - } - - void set_progress_options(bool progress, progress_data & data) { - if (progress) { - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); - curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress); - } - } - - void set_headers(const std::vector & headers) { - if (!headers.empty()) { - if (chunk) { - curl_slist_free_all(chunk); - chunk = 0; - } - - for (const auto & header : headers) { - chunk = curl_slist_append(chunk, header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); - } - } - - CURLcode perform(const std::string & url) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); - curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); -#ifdef _WIN32 - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - return curl_easy_perform(curl); - } - -#else // LLAMA_USE_CURL is not defined - -#define curl_off_t long long // temporary hack - - private: - // this is a direct translation of the cURL download() above - int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - try { - auto [cli, url_parts] = common_http_client(url); - - httplib::Headers headers; - for (const auto & h : headers_vec) { - size_t pos = h.find(':'); - if (pos != std::string::npos) { - headers.emplace(h.substr(0, pos), h.substr(pos + 2)); - } - } - - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - return 1; - } - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - return 1; - } - } - - size_t resume_offset = 0; - if (!output_file.empty() && std::filesystem::exists(output_file)) { - resume_offset = std::filesystem::file_size(output_file); - if (resume_offset > 0) { - headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); - } - } - - progress_data data; - data.file_size = resume_offset; - - long long total_size = 0; - long long received_this_session = 0; - - auto response_handler = - [&](const httplib::Response & response) { - if (resume_offset > 0 && response.status != 206) { - printe("\nServer does not support resuming. Restarting download.\n"); - out.file = freopen(output_file.c_str(), "wb", out.file); - if (!out.file) { - return false; - } - data.file_size = 0; - } - if (progress) { - if (response.has_header("Content-Length")) { - total_size = std::stoll(response.get_header_value("Content-Length")); - } else if (response.has_header("Content-Range")) { - auto range = response.get_header_value("Content-Range"); - auto slash = range.find('/'); - if (slash != std::string::npos) { - total_size = std::stoll(range.substr(slash + 1)); - } - } - } - return true; - }; - - auto content_receiver = - [&](const char * chunk, size_t length) { - if (out.file && fwrite(chunk, 1, length, out.file) != length) { - return false; - } - if (response_str) { - response_str->append(chunk, length); - } - received_this_session += length; - - if (progress && total_size > 0) { - update_progress(&data, total_size, received_this_session, 0, 0); - } - return true; - }; - - auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); - - if (data.printed) { - printe("\n"); - } - - if (!res) { - auto err = res.error(); - printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); - return 1; - } - - if (res->status >= 400) { - printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); - return 1; - } - - } catch (const std::exception & e) { - printe("HTTP request failed: %s\n", e.what()); - return 1; - } - return 0; - } - -#endif // LLAMA_USE_CURL - - static std::string human_readable_time(double seconds) { - int hrs = static_cast(seconds) / 3600; - int mins = (static_cast(seconds) % 3600) / 60; - int secs = static_cast(seconds) % 60; - - if (hrs > 0) { - return string_format("%dh %02dm %02ds", hrs, mins, secs); - } else if (mins > 0) { - return string_format("%dm %02ds", mins, secs); - } else { - return string_format("%ds", secs); - } - } - - static std::string human_readable_size(curl_off_t size) { - static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; - char length = sizeof(suffix) / sizeof(suffix[0]); - int i = 0; - double dbl_size = size; - if (size > 1024) { - for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { - dbl_size = size / 1024.0; - } - } - - return string_format("%.2f %s", dbl_size, suffix[i]); - } - - static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, - curl_off_t) { - progress_data * data = static_cast(ptr); - if (total_to_download <= 0) { - return 0; - } - - total_to_download += data->file_size; - const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; - const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download); - std::string progress_prefix = generate_progress_prefix(percentage); - - const double speed = calculate_speed(now_downloaded, data->start_time); - const double tim = (total_to_download - now_downloaded) / speed; - std::string progress_suffix = - generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim); - - int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix); - std::string progress_bar; - generate_progress_bar(progress_bar_width, percentage, progress_bar); - - print_progress(progress_prefix, progress_bar, progress_suffix); - data->printed = true; - - return 0; - } - - static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) { - return (now_downloaded_plus_file_size * 100) / total_to_download; - } - - static std::string generate_progress_prefix(curl_off_t percentage) { - return string_format("%3ld%% |", static_cast(percentage)); - } - - static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { - const auto now = std::chrono::steady_clock::now(); - const std::chrono::duration elapsed_seconds = now - start_time; - return now_downloaded / elapsed_seconds.count(); - } - - static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download, - double speed, double estimated_time) { - const int width = 10; - return string_format("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), - width, human_readable_size(total_to_download).c_str(), width, - human_readable_size(speed).c_str(), width, human_readable_time(estimated_time).c_str()); - } - - static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) { - int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3; - if (progress_bar_width < 1) { - progress_bar_width = 1; - } - - return progress_bar_width; - } - - static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage, - std::string & progress_bar) { - const curl_off_t pos = (percentage * progress_bar_width) / 100; - for (int i = 0; i < progress_bar_width; ++i) { - progress_bar.append((i < pos) ? "█" : " "); - } - - return progress_bar; - } - - static void print_progress(const std::string & progress_prefix, const std::string & progress_bar, - const std::string & progress_suffix) { - printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str()); - } - // Function to write data to a file - static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { - FILE * out = static_cast(stream); - return fwrite(ptr, size, nmemb, out); - } - - // Function to capture data into a string - static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { - std::string * str = static_cast(stream); - str->append(static_cast(ptr), size * nmemb); - return size * nmemb; - } - -}; - -class LlamaData { - public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; - std::vector messages; // TODO: switch to common_chat_msg - std::list msg_strs; - std::vector fmtted; - - int init(Opt & opt) { - model = initialize_model(opt); - if (!model) { - return 1; - } - - context = initialize_context(model, opt); - if (!context) { - return 1; - } - - sampler = initialize_sampler(opt); - - return 0; - } - - private: - int download(const std::string & url, const std::string & output_file, const bool progress, - const std::vector & headers = {}, std::string * response_str = nullptr) { - HttpClient http; - if (http.init(url, headers, output_file, progress, response_str)) { - return 1; - } - - return 0; - } - - // Helper function to handle model tag extraction and URL construction - std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { - std::string model_tag = "latest"; - const size_t colon_pos = model.find(':'); - if (colon_pos != std::string::npos) { - model_tag = model.substr(colon_pos + 1); - model = model.substr(0, colon_pos); - } - - std::string url = base_url + model + "/manifests/" + model_tag; - - return { model, url }; - } - - // Helper function to download and parse the manifest - int download_and_parse_manifest(const std::string & url, const std::vector & headers, - nlohmann::json & manifest) { - std::string manifest_str; - int ret = download(url, "", false, headers, &manifest_str); - if (ret) { - return ret; - } - - manifest = nlohmann::json::parse(manifest_str); - - return 0; - } - - int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) { - // Find the second occurrence of '/' after protocol string - size_t pos = model.find('/'); - pos = model.find('/', pos + 1); - std::string hfr, hff; - std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; - std::string url; - - if (pos == std::string::npos) { - auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/"); - hfr = model_name; - - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, headers, manifest); - if (ret) { - return ret; - } - - hff = manifest["ggufFile"]["rfilename"]; - } else { - hfr = model.substr(0, pos); - hff = model.substr(pos + 1); - } - - url = model_endpoint + hfr + "/resolve/main/" + hff; - - return download(url, bn, true, headers); - } - - int modelscope_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = "https://modelscope.cn/models/"; - return dl_from_endpoint(model_endpoint, model, bn); - } - - int huggingface_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = get_model_endpoint(); - return dl_from_endpoint(model_endpoint, model, bn); - } - - int ollama_dl(std::string & model, const std::string & bn) { - const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; - if (model.find('/') == std::string::npos) { - model = "library/" + model; - } - - auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, {}, manifest); - if (ret) { - return ret; - } - - std::string layer; - for (const auto & l : manifest["layers"]) { - if (l["mediaType"] == "application/vnd.ollama.image.model") { - layer = l["digest"]; - break; - } - } - - std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; - - return download(blob_url, bn, true, headers); - } - - int github_dl(const std::string & model, const std::string & bn) { - std::string repository = model; - std::string branch = "main"; - const size_t at_pos = model.find('@'); - if (at_pos != std::string::npos) { - repository = model.substr(0, at_pos); - branch = model.substr(at_pos + 1); - } - - const std::vector repo_parts = string_split(repository, "/"); - if (repo_parts.size() < 3) { - printe("Invalid GitHub repository format\n"); - return 1; - } - - const std::string & org = repo_parts[0]; - const std::string & project = repo_parts[1]; - std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; - for (size_t i = 2; i < repo_parts.size(); ++i) { - url += "/" + repo_parts[i]; - } - - return download(url, bn, true); - } - - int s3_dl(const std::string & model, const std::string & bn) { - const size_t slash_pos = model.find('/'); - if (slash_pos == std::string::npos) { - return 1; - } - - const std::string bucket = model.substr(0, slash_pos); - const std::string key = model.substr(slash_pos + 1); - const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); - const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); - if (!access_key || !secret_key) { - printe("AWS credentials not found in environment\n"); - return 1; - } - - // Generate AWS Signature Version 4 headers - // (Implementation requires HMAC-SHA256 and date handling) - // Get current timestamp - const time_t now = time(nullptr); - const tm tm = *gmtime(&now); - const std::string date = strftime_fmt("%Y%m%d", tm); - const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); - const std::vector headers = { - "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + - "/us-east-1/s3/aws4_request", - "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime - }; - - const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; - - return download(url, bn, true, headers); - } - - std::string basename(const std::string & path) { - const size_t pos = path.find_last_of("/\\"); - if (pos == std::string::npos) { - return path; - } - - return path.substr(pos + 1); - } - - int rm_until_substring(std::string & model_, const std::string & substring) { - const std::string::size_type pos = model_.find(substring); - if (pos == std::string::npos) { - return 1; - } - - model_ = model_.substr(pos + substring.size()); // Skip past the substring - return 0; - } - - int resolve_model(std::string & model_) { - int ret = 0; - if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - rm_until_substring(model_, "://"); - - return ret; - } - - const std::string bn = basename(model_); - if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") || - string_starts_with(model_, "hf.co/")) { - rm_until_substring(model_, "hf.co/"); - rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, bn); - } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) { - rm_until_substring(model_, "://"); - ret = modelscope_dl(model_, bn); - } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) && - !string_starts_with(model_, "https://ollama.com/library/")) { - ret = download(model_, bn, true); - } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { - rm_until_substring(model_, "github:"); - rm_until_substring(model_, "://"); - ret = github_dl(model_, bn); - } else if (string_starts_with(model_, "s3://")) { - rm_until_substring(model_, "://"); - ret = s3_dl(model_, bn); - } else { // ollama:// or nothing - rm_until_substring(model_, "ollama.com/library/"); - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, bn); - } - - model_ = bn; - - return ret; - } - - // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(Opt & opt) { - ggml_backend_load_all(); - resolve_model(opt.model_); - printe("\r" LOG_CLR_TO_EOL "Loading model"); - llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); - if (!model) { - printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); - } - - printe("\r" LOG_CLR_TO_EOL); - return model; - } - - // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { - llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); - if (!context) { - printe("%s: error: failed to create the llama_context\n", __func__); - } - - return context; - } - - // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler(const Opt & opt) { - llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); - - return sampler; - } -}; - -// Add a message to `messages` and store its content in `msg_strs` -static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { - llama_data.msg_strs.push_back(std::move(text)); - llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); -} - -// Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { - common_chat_templates_inputs inputs; - for (const auto & msg : llama_data.messages) { - common_chat_msg cmsg; - cmsg.role = msg.role; - cmsg.content = msg.content; - inputs.messages.push_back(cmsg); - } - inputs.add_generation_prompt = append; - inputs.use_jinja = use_jinja; - - auto chat_params = common_chat_templates_apply(tmpls, inputs); - // TODO: use other params for tool calls. - auto result = chat_params.prompt; - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); -} - -// Function to tokenize the prompt -static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens, const LlamaData & llama_data) { - const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; - int n_tokens = prompt.size() + 2 * is_first; - prompt_tokens.resize(n_tokens); - n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (n_tokens == std::numeric_limits::min()) { - printe("tokenization failed: input too large\n"); - return -1; - } - if (n_tokens < 0) { - prompt_tokens.resize(-n_tokens); - int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (check != -n_tokens) { - printe("failed to tokenize the prompt (size mismatch)\n"); - return -1; - } - n_tokens = check; - } else { - prompt_tokens.resize(n_tokens); - } - return n_tokens; -} - -// Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { - const int n_ctx = llama_n_ctx(ctx.get()); - const int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx.get()), 0); - if (n_ctx_used + batch.n_tokens > n_ctx) { - printf(LOG_COL_DEFAULT "\n"); - printe("context size exceeded\n"); - return 1; - } - - return 0; -} - -// convert the token to a string -static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { - char buf[256]; - int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - printe("failed to convert token to piece\n"); - return 1; - } - - piece = std::string(buf, n); - return 0; -} - -static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) { - printf("%s", piece.c_str()); - fflush(stdout); - response += piece; -} - -// helper function to evaluate a prompt and generate a response -static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); - - std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { - return 1; - } - - // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); - llama_token new_token_id; - while (true) { - check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { - printe("failed to decode\n"); - return 1; - } - - // sample the next token, check is it an end of generation? - new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_vocab_is_eog(vocab, new_token_id)) { - break; - } - - std::string piece; - if (convert_token_to_string(vocab, new_token_id, piece)) { - return 1; - } - - print_word_and_concatenate_to_response(piece, response); - - // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); - } - - printf(LOG_COL_DEFAULT); - return 0; -} - -static int read_user_input(std::string & user_input) { - static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX"); - static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> "; -#ifdef WIN32 - printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix); - - std::getline(std::cin, user_input); - if (std::cin.eof()) { - printf("\n"); - return 1; - } -#else - std::unique_ptr line(const_cast(linenoise(prompt_prefix)), free); - if (!line) { - return 1; - } - - user_input = line.get(); -#endif - - if (user_input == "/bye") { - return 1; - } - - if (user_input.empty()) { - return 2; - } - -#ifndef WIN32 - linenoiseHistoryAdd(line.get()); -#endif - - return 0; // Should have data in happy path -} - -// Function to generate a response based on the prompt -static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, - const bool stdout_a_terminal) { - // Set response color - if (stdout_a_terminal) { - printf(LOG_COL_YELLOW); - } - - if (generate(llama_data, prompt, response)) { - printe("failed to generate response\n"); - return 1; - } - - // End response with color reset and newline - printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : ""); - return 0; -} - -// Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); - if (new_len < 0) { - printe("failed to apply the chat template\n"); - return -1; - } - - output_length = new_len; - return 0; -} - -// Helper function to handle user input -static int handle_user_input(std::string & user_input, const std::string & user) { - if (!user.empty()) { - user_input = user; - return 0; // No need for interactive input - } - - return read_user_input(user_input); // Returns true if input ends the loop -} - -static bool is_stdin_a_terminal() { -#if defined(_WIN32) - HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdin, &mode); -#else - return isatty(STDIN_FILENO); -#endif -} - -static bool is_stdout_a_terminal() { -#if defined(_WIN32) - HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdout, &mode); -#else - return isatty(STDOUT_FILENO); -#endif -} - -// Function to handle user input -static int get_user_input(std::string & user_input, const std::string & user) { - while (true) { - const int ret = handle_user_input(user_input, user); - if (ret == 1) { - return 1; - } - - if (ret == 2) { - continue; - } - - break; - } - - return 0; -} - -// Reads a chat template file to be used -static std::string read_chat_template_file(const std::string & chat_template_file) { - File file; - if (!file.open(chat_template_file, "r")) { - printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno)); - return ""; - } - - return file.to_string(); -} - -static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data, - const common_chat_templates_ptr & chat_templates, int & prev_len, - const bool stdout_a_terminal) { - add_message("user", opt.user.empty() ? user_input : opt.user, llama_data); - int new_len; - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) { - return 1; - } - - std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); - std::string response; - if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { - return 1; - } - - if (!opt.user.empty()) { - return 2; - } - - add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) { - return 1; - } - - return 0; -} - -// Main chat loop function -static int chat_loop(LlamaData & llama_data, const Opt & opt) { - int prev_len = 0; - llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - std::string chat_template; - if (!opt.chat_template_file.empty()) { - chat_template = read_chat_template_file(opt.chat_template_file); - } - - common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template); - static const bool stdout_a_terminal = is_stdout_a_terminal(); - while (true) { - // Get user input - std::string user_input; - if (get_user_input(user_input, opt.user) == 1) { - return 0; - } - - const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal); - if (ret == 1) { - return 1; - } else if (ret == 2) { - break; - } - } - - return 0; -} - -static void log_callback(const enum ggml_log_level level, const char * text, void * p) { - const Opt * opt = static_cast(p); - if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) { - printe("%s", text); - } -} - -static std::string read_pipe_data() { - std::ostringstream result; - result << std::cin.rdbuf(); // Read all data from std::cin - return result.str(); -} - -static void ctrl_c_handling() { -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset(&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined(_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif -} - -int main(int argc, const char ** argv) { - ctrl_c_handling(); - Opt opt; - const int ret = opt.init(argc, argv); - if (ret == 2) { - return 0; - } else if (ret) { - return 1; - } - - if (!is_stdin_a_terminal()) { - if (!opt.user.empty()) { - opt.user += "\n\n"; - } - - opt.user += read_pipe_data(); - } - - llama_log_set(log_callback, &opt); - LlamaData llama_data; - if (llama_data.init(opt)) { - return 1; - } - - if (chat_loop(llama_data, opt)) { - return 1; - } - - return 0; -}