From c31fc8b966817b2f0b277fd28e04a189e388972a Mon Sep 17 00:00:00 2001 From: "Gilad S." <7817232+giladgd@users.noreply.github.com> Date: Sat, 4 Jan 2025 10:17:31 +0200 Subject: [PATCH 01/25] fix: Vulkan shader gen binary path (#11037) --- ggml/src/ggml-vulkan/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 6d46e5f24..9501de736 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -73,7 +73,7 @@ if (Vulkan_FOUND) OUTPUT ${_ggml_vk_header} ${_ggml_vk_source} - COMMAND ${_ggml_vk_genshaders_cmd} + COMMAND "$/${_ggml_vk_genshaders_cmd}" --glslc ${Vulkan_GLSLC_EXECUTABLE} --input-dir ${_ggml_vk_input_dir} --output-dir ${_ggml_vk_output_dir} From db68c93b57bfdf6da1fbdae81080382d6998cbc9 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 19 Dec 2024 03:50:12 +0100 Subject: [PATCH 02/25] ggml : improve inputs log sched_print_assignments (ggml/1053) This commit attempts to improve the log message for the inputs of the splits in the sched_print_assignments function. The motivation for this change is that currently even if there are no inputs a colon is displayed at the end of the line, which can make it a little confusing when reading the output as it could be interpreted as the line below are inputs when they are in fact nodes. With this change the colon will only be printed if there actually are inputs. --- ggml/src/ggml-backend.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index fdb4b986f..e2d6c4056 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -795,9 +795,12 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; - GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { + if (j == 0) { + GGML_LOG_DEBUG(": "); + } GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); } From 5e3b08d606b5b0caaea16541b504c3bba8f3ec1d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 4 Jan 2025 10:53:54 +0200 Subject: [PATCH 03/25] ggml : do not install metal source when embed library (ggml/1054) --- ggml/CMakeLists.txt | 20 -------------------- ggml/src/ggml-metal/CMakeLists.txt | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index e33d97482..393506533 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -252,26 +252,6 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") install(TARGETS ggml LIBRARY PUBLIC_HEADER) install(TARGETS ggml-base LIBRARY) -# FIXME: this should be done in the backend cmake files -if (GGML_METAL) - # FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY? 
- install( - FILES src/ggml-metal/ggml-metal.metal - PERMISSIONS - OWNER_READ - OWNER_WRITE - GROUP_READ - WORLD_READ - DESTINATION ${CMAKE_INSTALL_BINDIR}) - - if (NOT GGML_METAL_EMBED_LIBRARY) - install( - FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() -endif() - if (GGML_STANDALONE) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 1bad27206..89fcde2fa 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -103,3 +103,19 @@ else() DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib ) endif() # GGML_METAL_EMBED_LIBRARY + +if (NOT GGML_METAL_EMBED_LIBRARY) + install( + FILES src/ggml-metal/ggml-metal.metal + PERMISSIONS + OWNER_READ + OWNER_WRITE + GROUP_READ + WORLD_READ + DESTINATION ${CMAKE_INSTALL_BINDIR}) + + install( + FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() From 78c678517530d411b4263341cdb4dc28c9d117c8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 4 Jan 2025 10:54:01 +0200 Subject: [PATCH 04/25] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index b4ac38bbf..b67445ecd 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -e6d93f40dffe8733d5d72f1d8fa6b3ca27ae899f +a2af72be7baf5b1f4a33d34e77e509e5e85b7cd7 From 46be942214e295cd34660bbbd6b846155d1c36a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Sat, 4 Jan 2025 09:33:31 -0500 Subject: [PATCH 05/25] llama : add support for the cohere2 model architecture (#10900) --- convert_hf_to_gguf.py | 18 +++++ gguf-py/gguf/constants.py | 14 ++++ src/llama-arch.cpp | 16 ++++ src/llama-arch.h | 1 + src/llama-model.cpp | 11 +++ src/llama.cpp | 161 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 221 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4e6c0f60c..d4441bbe9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3373,6 +3373,24 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("Cohere2ForCausalLM") +class Cohere2Model(Model): + model_arch = gguf.MODEL_ARCH.COHERE2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + rotary_pct = self.hparams["rotary_pct"] + hidden_size = self.hparams["hidden_size"] + num_attention_heads = self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + @Model.register("OlmoForCausalLM") @Model.register("OLMoForCausalLM") class OlmoModel(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 273370370..cdf79673b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -255,6 +255,7 @@ class MODEL_ARCH(IntEnum): MAMBA = auto() XVERSE = auto() COMMAND_R = auto() + COHERE2 = auto() DBRX = auto() OLMO = auto() OLMO2 = auto() @@ -437,6 +438,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", 
MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.OLMO2: "olmo2", @@ -1136,6 +1138,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_Q_NORM, ], + MODEL_ARCH.COHERE2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.DBRX: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a60038385..fea4b21d3 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -39,6 +39,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_OLMO2, "olmo2" }, @@ -807,6 +808,21 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DBRX, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 446e72eeb..10bd619a4 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -43,6 +43,7 @@ enum llm_arch { LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, + LLM_ARCH_COHERE2, LLM_ARCH_DBRX, LLM_ARCH_OLMO, LLM_ARCH_OLMO2, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index ace0ba262..c356abded 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -786,6 +786,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_COHERE2: + { + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_DBRX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2031,6 +2041,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: + case LLM_ARCH_COHERE2: case LLM_ARCH_OLMO: case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: diff --git a/src/llama.cpp b/src/llama.cpp index d7110b90b..50e9191fa 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1552,6 +1552,32 @@ static bool llm_load_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_COHERE2: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + 
llama_model_loader::TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + break; case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7633,6 +7659,137 @@ struct llm_build_context { } + struct ggml_cgraph * build_cohere2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const float f_logit_scale = hparams.f_logit_scale; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + // cohere2 requires different mask for layers using sliding window (SWA) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + + // sliding window switch pattern + const int32_t sliding_window_pattern = 4; + + for (int il = 0; il < n_layer; ++il) { + // three layers sliding window attention (window size 4096) and ROPE + // fourth layer uses global attention without positional embeddings + const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1); + struct ggml_tensor * KQ_mask_l = is_sliding ? 
KQ_mask_swa : KQ_mask; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + struct ggml_tensor * ffn_inp = cur; + + // self-attention + { + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + if (is_sliding) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, + beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + } else { + // For non-sliding layers, just reshape without applying RoPE + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + } + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + struct ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, + cb, il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + // ref: https://allenai.org/olmo // based on the original build_llama() function, changes: // * non-parametric layer norm @@ -10384,6 +10541,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_command_r(); } break; + case LLM_ARCH_COHERE2: + { + result = llm.build_cohere2(); + } break; case LLM_ARCH_DBRX: { result = 
llm.build_dbrx(); From f922a9c542ee117550a168395c63ea79261f5c99 Mon Sep 17 00:00:00 2001 From: matt23654 Date: Sat, 4 Jan 2025 16:10:30 +0000 Subject: [PATCH 06/25] [GGML][RPC] Support for models with non-512-aligned tensors over RPC. (#11047) * Added init tensor calling code * Added get_alloc_size forwarding * Cleaned up and improved type/error handling. * fix: remove trailing whitespaces. * Cleanup and use GGML error logging functions. * Handle potentially dangerous edge cases. * Apply suggestions from code review Co-authored-by: Diego Devesa --------- Co-authored-by: Diego Devesa --- ggml/src/ggml-rpc/ggml-rpc.cpp | 140 +++++++++++++++++++++++++++++++-- 1 file changed, 134 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 431082426..2213aba9f 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -93,9 +93,23 @@ enum rpc_cmd { RPC_CMD_COPY_TENSOR, RPC_CMD_GRAPH_COMPUTE, RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_COUNT, }; +struct rpc_msg_get_alloc_size_req { + rpc_tensor tensor; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + struct rpc_msg_alloc_buffer_req { uint64_t size; }; @@ -461,10 +475,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - UNUSED(buffer); - if (ggml_is_quantized(tensor->type)) { - // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized - GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor"); + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. + // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + GGML_ASSERT(status); } } @@ -577,8 +599,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { } static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - UNUSED(buft); - return ggml_nbytes(tensor); + // See comments in init_tensor. 
+ if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request; + + request.tensor = serialize_tensor(tensor); + + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + GGML_ASSERT(status); + + return response.alloc_size; + } else { + return ggml_nbytes(tensor); + } } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -757,6 +794,8 @@ public: bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool init_tensor(const rpc_msg_init_tensor_req & request); + bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); private: ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -770,6 +809,36 @@ private: std::unordered_set buffers; }; +bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + ggml_backend_buffer_type_t buft; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); + ggml_free(ctx); + return false; + } + + if (tensor->buffer == nullptr) { + //No buffer allocated. + buft = ggml_backend_get_default_buffer_type(backend); + } else { + buft = tensor->buffer->buft; + } + + response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); + + ggml_free(ctx); + return true; +} + void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); @@ -905,6 +974,40 @@ bool rpc_server::set_tensor(const std::vector & input) { return true; } +bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); + ggml_free(ctx); + return false; + } + + // Call the backend's buffer_init_tensor function + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer && buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } else { + GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + } + + if (tensor->extra != nullptr) { + // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. + // Currently unimplemented. 
+ GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + return true; +} + bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response) { struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -1058,6 +1161,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } break; } + case RPC_CMD_GET_ALLOC_SIZE: { + rpc_msg_get_alloc_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_get_alloc_size_rsp response; + server.get_alloc_size(request, response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_GET_ALIGNMENT: { if (!recv_msg(sockfd, nullptr, 0)) { return; @@ -1133,6 +1248,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre } break; } + case RPC_CMD_INIT_TENSOR: { + rpc_msg_init_tensor_req request; + if (!recv_msg(sockfd, &request,sizeof(request))) { + return; + } + if (!server.init_tensor(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { From 9394bbd484f802ce80d2858033583af3ef700d25 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Sat, 4 Jan 2025 21:06:11 +0100 Subject: [PATCH 07/25] llama : Add support for DeepSeek V3 (#11049) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * convert : extend DEEPSEEK2 model architecture to support DeepseekV3ForCausalLM by adding EXPERT_WEIGHTS_NORM and EXPERT_GATING_FUNC model parameters and FFN_EXP_PROBS_B tensor type * vocab : add DeepSeek V3 pre-tokenizer regexes * unicode : handle ACCENT_MARK and SYMBOL categories in regex * llama : add DeepSeek V3 chat template, handle new model parameters and tensor types --------- Co-authored-by: Stanisław Szymczyk --- convert_hf_to_gguf.py | 23 +++++++++++++++++ convert_hf_to_gguf_update.py | 1 + gguf-py/gguf/constants.py | 10 ++++++++ gguf-py/gguf/gguf_writer.py | 7 ++++++ gguf-py/gguf/tensor_mapping.py | 4 +++ include/llama.h | 1 + src/llama-arch.cpp | 4 +++ src/llama-arch.h | 3 +++ src/llama-chat.cpp | 18 ++++++++++++++ src/llama-chat.h | 1 + src/llama-hparams.h | 12 +++++++-- src/llama-model.cpp | 23 +++++++++++++++++ src/llama-model.h | 2 ++ src/llama-vocab.cpp | 7 ++++++ src/llama.cpp | 45 +++++++++++++++++++++++++++++++--- src/unicode.cpp | 6 +++++ 16 files changed, 162 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d4441bbe9..01b58f976 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -687,6 +687,9 @@ class Model: if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1": # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct res = "megrez" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" if res is None: logger.warning("\n") @@ -3849,6 +3852,7 @@ class DeepseekModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3870,6 +3874,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) 
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3882,6 +3895,16 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index fea23ddb4..56edc64a7 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -107,6 +107,7 @@ models = [ {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, ] diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cdf79673b..9d0e7489f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -102,6 +102,8 @@ class Keys: EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -313,6 +315,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -498,6 +501,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", @@ -1290,6 +1294,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, ], MODEL_ARCH.CHATGLM : [ 
MODEL_TENSOR.TOKEN_EMBD, @@ -1590,6 +1595,11 @@ class GGMLQuantizationType(IntEnum): TQ2_0 = 35 +class ExpertGatingFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + + # TODO: add GGMLFileType from ggml_ftype in ggml.h diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3023b539a..4a0a65e3c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ from .constants import ( RopeScalingType, PoolingType, TokenType, + ExpertGatingFuncType, ) from .quants import quant_shape_from_byte_shape @@ -715,6 +716,12 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7009a11d4..efe2a4aa4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -276,6 +276,10 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe ), + MODEL_TENSOR.FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox diff --git a/include/llama.h b/include/llama.h index 7b305b299..a0d5ba5dd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -105,6 +105,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, }; enum llama_rope_type { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fea4b21d3..007d79f82 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -92,6 +92,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, + { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -984,6 +986,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, { @@ -1366,6 +1369,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 10bd619a4..45e458bb9 100644 --- a/src/llama-arch.h +++ 
b/src/llama-arch.h @@ -96,6 +96,8 @@ enum llm_kv { LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_SHARED_COUNT, LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -231,6 +233,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index a07e9cf00..44670d3d8 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -45,6 +45,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, @@ -148,6 +149,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MINICPM; } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct @@ -453,6 +456,21 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "Assistant:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) { + // DeepSeek-V3 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << LU8("<|User|>") << message->content; + } else if (role == "assistant") { + ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << LU8("<|Assistant|>"); + } } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct diff --git a/src/llama-chat.h b/src/llama-chat.h index 364318c27..b8e94d9ef 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -25,6 +25,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_VICUNA_ORCA, LLM_CHAT_TEMPLATE_DEEPSEEK, LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_DEEPSEEK_3, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGML_3, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 3a76b71a4..a29f20ec4 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,13 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 +#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 + +enum llama_expert_gating_func_type { + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, +}; struct llama_hparams_posnet { uint32_t n_embd; @@ -54,7 +60,9 @@ struct llama_hparams { uint32_t n_expert_shared = 0; uint32_t n_norm_groups = 0; - float expert_weights_scale = 0.0; + float expert_weights_scale = 0.0; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; float f_norm_eps; float 
f_norm_rms_eps; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c356abded..405e0528f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -66,6 +66,7 @@ const char * llm_type_name(llm_type type) { case MODEL_70B: return "70B"; case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; + case MODEL_671B: return "671B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; @@ -125,6 +126,14 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { } } +static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) { + switch (type) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax"; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid"; + default: return "unknown"; + } +} + std::string llama_model_arch_name (const llama_model & model) { return llm_arch_name(model.arch); } @@ -933,11 +942,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); switch (hparams.n_layer) { case 27: model.type = e_model::MODEL_16B; break; case 60: model.type = e_model::MODEL_236B; break; + case 61: model.type = e_model::MODEL_671B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -1259,6 +1276,10 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) { tokenizer_pre == "deepseek-coder") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "falcon") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -1941,6 +1962,8 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func)); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); } diff --git a/src/llama-model.h b/src/llama-model.h index 01c780c41..ce038932d 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -63,6 +63,7 @@ enum llm_type { MODEL_70B, MODEL_236B, MODEL_314B, + MODEL_671B, MODEL_SMALL, MODEL_MEDIUM, MODEL_LARGE, @@ -213,6 +214,7 @@ struct llama_layer { struct ggml_tensor * ffn_down_b = nullptr; // b2 struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act = nullptr; + struct ggml_tensor * ffn_exp_probs_b = nullptr; // mamba proj 
struct ggml_tensor * ssm_in = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 909e04871..3fcfcaa3f 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -382,6 +382,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "\\p{N}+", }; break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + regex_exprs = { + "\\p{N}{1,3}", + "[一-龥぀-ゟ゠-ヿ]+", + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", diff --git a/src/llama.cpp b/src/llama.cpp index 50e9191fa..ea78ea487 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1857,6 +1857,7 @@ static bool llm_load_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } else { layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (n_expert == 0) { throw std::runtime_error("n_expert must be > 0"); @@ -2837,12 +2838,14 @@ static struct ggml_tensor * llm_build_moe_ffn( struct ggml_tensor * up_exps, struct ggml_tensor * gate_exps, struct ggml_tensor * down_exps, + struct ggml_tensor * exp_probs_b, int64_t n_expert, int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, bool scale_w, float w_scale, +llama_expert_gating_func_type gating_op, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -2851,11 +2854,31 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); - ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + ggml_tensor * probs = nullptr; + switch (gating_op) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: + { + probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: + { + probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens] + } break; + default: + GGML_ABORT("fatal error"); + } cb(probs, "ffn_moe_probs", il); + // add experts selection bias - introduced in DeepSeek V3 + // leave probs unbiased as it's later used to get expert weights + ggml_tensor * selection_probs = probs; + if (exp_probs_b != nullptr) { + selection_probs = ggml_add(ctx, probs, exp_probs_b); + cb(selection_probs, "ffn_moe_probs_biased", il); + } + // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -3976,9 +3999,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); } @@ -4628,9 +4653,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_GELU, true, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -4769,9 +4796,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, 
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -6017,9 +6046,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -8142,9 +8173,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -8539,9 +8572,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -8680,9 +8715,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, false, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, cb, il); cb(moe_out, "ffn_moe_out", il); @@ -8909,9 +8946,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, - LLM_FFN_SILU, false, + LLM_FFN_SILU, hparams.expert_weights_norm, true, hparams.expert_weights_scale, + (enum llama_expert_gating_func_type) hparams.expert_gating_func, cb, il); cb(moe_out, "ffn_moe_out", il); diff --git a/src/unicode.cpp b/src/unicode.cpp index 8ed6b1a51..7aca6544b 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -667,18 +667,24 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{N}", unicode_cpt_flags::NUMBER }, { "\\p{L}", unicode_cpt_flags::LETTER }, { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, + { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, + { "\\p{S}", unicode_cpt_flags::SYMBOL }, }; static const std::map k_ucat_cpt = { { unicode_cpt_flags::NUMBER, 0xD1 }, { unicode_cpt_flags::LETTER, 0xD2 }, { unicode_cpt_flags::PUNCTUATION, 0xD3 }, + { unicode_cpt_flags::ACCENT_MARK, 0xD4 }, + { unicode_cpt_flags::SYMBOL, 0xD5 }, }; static const std::map k_ucat_map = { { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9 { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints + { unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`| }; // compute collapsed codepoints only if needed by at least one regex From b56f079e28fda692f11a8b59200ceb815b05d419 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 4 Jan 2025 21:09:59 +0100 Subject: [PATCH 08/25] Vulkan: Add device-specific blacklist for coopmat for the AMD proprietary driver (#11074) * Vulkan: Add device-specific blacklist for coopmat for the AMD proprietary driver * Add (TM) to AMD name check --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 30 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp 
b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 020e61280..d75cd6d61 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2040,6 +2040,8 @@ static void ggml_vk_load_shaders(vk_device& device) { std::cerr << "Done!" << std::endl; } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); + static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2175,9 +2177,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { - // Intel drivers don't support coopmat properly yet - // Only RADV supports coopmat properly on AMD + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { device->coopmat_support = false; } @@ -2515,7 +2515,6 @@ static vk_device ggml_vk_get_device(size_t idx) { return vk_instance.devices[idx]; } - static void ggml_vk_print_gpu_info(size_t idx) { GGML_ASSERT(idx < vk_instance.device_indices.size()); size_t dev_num = vk_instance.device_indices[idx]; @@ -2565,9 +2564,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } - if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) { - // Intel drivers don't support coopmat properly yet - // Only RADV supports coopmat properly on AMD + if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { coopmat_support = false; } @@ -8088,6 +8085,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Intel drivers don't support coopmat properly yet + return false; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + const std::string name = props.deviceName; + return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs + name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs + name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + } + return true; + default: + return true; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS From 46e3556e01b824e52395fb050b29804b6cff2a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 6 Jan 2025 02:33:52 +0100 Subject: [PATCH 09/25] CUDA: add BF16 support (#11093) * CUDA: add BF16 support --- ggml/src/ggml-cuda/convert.cu | 2 + ggml/src/ggml-cuda/ggml-cuda.cu | 3 +- ggml/src/ggml-cuda/mmv.cu | 114 ++++++++++++++++++++---------- ggml/src/ggml-cuda/vendors/cuda.h | 1 + ggml/src/ggml-cuda/vendors/hip.h | 3 + ggml/src/ggml-cuda/vendors/musa.h | 3 + 6 files changed, 87 insertions(+), 39 deletions(-) diff --git 
a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 3896f956d..5b0dfacef 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c180adc84..0b06be729 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_mul_mat_vec = src0->type == GGML_TYPE_F16 + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) @@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_BF16: #ifdef GGML_USE_MUSA if (a->type == GGML_TYPE_Q3_K) { return false; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index a4b4f6bc1..ac45f2d17 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -1,9 +1,9 @@ #include "common.cuh" #include "mmv.cuh" -template +template static __global__ void mul_mat_vec( - const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { const int64_t row = blockIdx.x; const int64_t channel = blockIdx.z; @@ -13,7 +13,6 @@ static __global__ void mul_mat_vec( y += channel *stride_channel_y; dst += channel *stride_channel_dst; - const half2 * x2 = (const half2 *) x; const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; @@ -28,28 +27,44 @@ static __global__ void mul_mat_vec( float sumf; - if (std::is_same::value) { + if constexpr (std::is_same::value) { + const half2 * x2 = (const half2 *) x; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + } else if constexpr (std::is_same::value) { + const int * x2 = (const int *) x; sumf = 0.0f; for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmpx = __half22float2(x2[col2]); + const int tmpx = x2[col2]; const float2 tmpy = y2[col2]; - sumf += tmpx.x * tmpy.x; - sumf += tmpx.y * tmpy.y; + sumf += 
float(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; } } else { -#ifdef FP16_AVAILABLE - half2 sumh2 = make_half2(0.0f, 0.0f); - - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmp = y2[col2]; - sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); - } - - sumf = __low2float(sumh2) + __high2float(sumh2); -#else - NO_DEVICE_CODE; -#endif // FP16_AVAILABLE + static_assert(std::is_same::value, "unsupported type"); } sumf = warp_reduce_sum(sumf); @@ -71,9 +86,9 @@ static __global__ void mul_mat_vec( dst[row] = sumf; } -template +template static void launch_mul_mat_vec_cuda( - const half * x, const float * y, float * dst, + const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, cudaStream_t stream) { @@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda( const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 64: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 96: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 128: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 160: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 192: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 224: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; case 256: { - mul_mat_vec<<>> + mul_mat_vec<<>> (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); } break; default: { @@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda( } } +template static void mul_mat_vec_cuda( - const half * x, const float * y, float * dst, + const T * x, const float * y, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, enum ggml_prec prec, cudaStream_t stream) { switch (prec) { case GGML_PREC_DEFAULT: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, stream); } break; case GGML_PREC_F32: { - launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, stream); } break; } } void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * 
src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const half * src0_d = (const half *) src0->data; const float * src1_d = (const float *) src1->data; float * dst_d = (float *) dst->data; @@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } } void ggml_cuda_op_mul_mat_vec( @@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec( const int64_t channel_stride_y = 0; const int64_t channel_stride_dst = 0; - mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } GGML_UNUSED(ctx); GGML_UNUSED(src1); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index db9f6a165..1746b0732 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #if CUDART_VERSION < 11020 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 3205534d6..c905b15d7 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() 
#include "rocblas/rocblas.h" @@ -121,6 +122,8 @@ #define __has_builtin(x) 0 #endif +typedef hip_bfloat16 nv_bfloat16; + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1604b8229..6cc1b69ee 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F @@ -132,3 +133,5 @@ #define cudaKernelNodeParams musaKernelNodeParams #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamEndCapture musaStreamEndCapture + +typedef mt_bfloat16 nv_bfloat16; From 5047dd3546951dea3d65c02257d06c46c8662338 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 10:52:01 +0200 Subject: [PATCH 10/25] llama : use _impl suffix instead of _internal (#11060) ggml-ci --- src/llama-quant.cpp | 20 ++++++++++---------- src/llama.cpp | 16 ++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 42974f8f1..104f90343 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -22,7 +22,7 @@ static void zeros(std::ofstream & file, size_t n) { } } -struct quantize_state_internal { +struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -43,13 +43,13 @@ struct quantize_state_internal { // used to figure out if a model shares tok_embd with the output weight bool has_output = false; - quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params) + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) : model(model) , params(params) {} }; -static void llama_tensor_dequantize_internal( +static void llama_tensor_dequantize_impl( struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread ) { @@ -121,7 +121,7 @@ static void llama_tensor_dequantize_internal( workers.clear(); } -static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); // TODO: avoid hardcoded tensor names - use the TN_* constants @@ -410,7 +410,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return new_type; } -static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { +static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { if (nthread < 2) { // single-thread size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); @@ -464,7 +464,7 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa return new_size; } -static void llama_model_quantize_internal(const std::string & 
fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { +static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; @@ -534,7 +534,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llm_load_hparams(ml, model); llm_load_stats (ml, model); - struct quantize_state_internal qs(model, params); + struct quantize_state_impl qs(model, params); if (params->only_copy) { ftype = model.ftype; @@ -837,7 +837,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { - llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread); + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); f32_data = (float *) f32_conv_buf.data(); } @@ -866,7 +866,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; - new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); } @@ -919,7 +919,7 @@ uint32_t llama_model_quantize( const char * fname_out, const llama_model_quantize_params * params) { try { - llama_model_quantize_internal(fname_inp, fname_out, params); + llama_model_quantize_impl(fname_inp, fname_out, params); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); return 1; diff --git a/src/llama.cpp b/src/llama.cpp index ea78ea487..4a6798f41 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10717,7 +10717,7 @@ static enum ggml_status llama_graph_compute( // return positive int on warning // return negative int on error // -static int llama_decode_internal( +static int llama_decode_impl( llama_context & lctx, llama_batch inp_batch) { @@ -11052,7 +11052,7 @@ static int llama_decode_internal( // return positive int on warning // return negative int on error // -static int llama_encode_internal( +static int llama_encode_impl( llama_context & lctx, llama_batch inp_batch) { @@ -11234,7 +11234,7 @@ static int llama_encode_internal( } // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache -static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { +static void llama_kv_cache_defrag_impl(struct llama_context & lctx) { auto & kv_self = lctx.kv_self; const auto & hparams = lctx.model.hparams; @@ -11454,7 +11454,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0); } -static void llama_kv_cache_update_internal(struct llama_context & lctx) { +static void llama_kv_cache_update_impl(struct llama_context & lctx) { bool need_reserve = false; if (lctx.kv_self.has_shift) { @@ -11490,7 +11490,7 @@ static void 
llama_kv_cache_update_internal(struct llama_context & lctx) { // defragment the KV cache if needed if (lctx.kv_self.do_defrag) { - llama_kv_cache_defrag_internal(lctx); + llama_kv_cache_defrag_impl(lctx); need_reserve = true; @@ -12191,7 +12191,7 @@ void llama_kv_cache_defrag(struct llama_context * ctx) { } void llama_kv_cache_update(struct llama_context * ctx) { - llama_kv_cache_update_internal(*ctx); + llama_kv_cache_update_impl(*ctx); } bool llama_kv_cache_can_shift(struct llama_context * ctx) { @@ -12203,7 +12203,7 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_encode_internal(*ctx, batch); + const int ret = llama_encode_impl(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -12214,7 +12214,7 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_decode_internal(*ctx, batch); + const int ret = llama_decode_impl(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } From 727368c60f2ebf2d6a7473a4a9f80957ab063a8e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 10:52:15 +0200 Subject: [PATCH 11/25] llama : use LLAMA_TOKEN_NULL (#11062) ggml-ci --- common/common.cpp | 2 +- common/ngram-cache.cpp | 24 +++++++------- common/ngram-cache.h | 4 +-- examples/batched/batched.cpp | 2 +- .../convert-llama2c-to-ggml.cpp | 4 +-- examples/main/main.cpp | 4 +-- examples/server/utils.hpp | 2 +- include/llama.h | 1 - src/llama-model.cpp | 32 +++++++++---------- src/llama-sampling.cpp | 8 ++--- src/llama-vocab.cpp | 24 +++++++------- 11 files changed, 53 insertions(+), 54 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 4bb140ee2..d6a7ab753 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -982,7 +982,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; } tmp.clear(); diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp index a9dfb6714..a057ae45f 100644 --- a/common/ngram-cache.cpp +++ b/common/ngram-cache.cpp @@ -65,13 +65,13 @@ constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66}; static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) { common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); if (part_static_it == nc_static.end()) { - return -1; + return LLAMA_TOKEN_NULL; } const common_ngram_cache_part part_static = part_static_it->second; int max_count_static = 0; int sum_count_static = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_static : part_static) { const llama_token token = token_count_static.first; @@ -85,10 +85,10 @@ static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram } if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) { - return -1; + return LLAMA_TOKEN_NULL; } if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) { - return -1; + return LLAMA_TOKEN_NULL; } return max_token; } @@ -98,9 +98,9 @@ static llama_token 
try_draft( common_ngram_cache & nc_primary, const std::vector & ngrams_primary, common_ngram_cache_part & part_static, const int * min_sample_size, const int * min_percent) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; - for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) { + for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) { const common_ngram ngram_primary = ngrams_primary[i]; common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); @@ -112,7 +112,7 @@ static llama_token try_draft( int max_count_primary = 0; int max_count_static = 0; int sum_count_primary = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_primary : part_primary) { const llama_token token = token_count_primary.first; @@ -154,7 +154,7 @@ void common_ngram_cache_draft( } while ((int) draft.size()-1 < n_draft) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; common_ngram ngram_static; @@ -177,17 +177,17 @@ void common_ngram_cache_draft( } ngrams_cd.push_back(ngram_cd); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_static, ngram_static); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { break; } diff --git a/common/ngram-cache.h b/common/ngram-cache.h index 09c2b0319..dfe012abe 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -17,13 +17,13 @@ struct common_ngram { common_ngram() { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = -1; + tokens[i] = LLAMA_TOKEN_NULL; } } common_ngram(const llama_token * input, const int ngram_size) { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = i < ngram_size ? input[i] : -1; + tokens[i] = i < ngram_size ? 
input[i] : LLAMA_TOKEN_NULL; } } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index e2e01f2d5..2e25b62f6 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -120,7 +120,7 @@ int main(int argc, char ** argv) { } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = llama_token_bos(model); } diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 736035d78..9c3a0c367 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -689,8 +689,8 @@ static void save_as_llama_model( gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID); - gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, -1); - gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, -1); + gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL); + gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL); gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx); gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b5e477f5b..aaee47e32 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -494,7 +494,7 @@ int main(int argc, char ** argv) { } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = llama_token_bos(model); } @@ -831,7 +831,7 @@ int main(int argc, char ** argv) { // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { llama_token eot = llama_token_eot(model); - embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot); + embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot); need_insert_eot = false; } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index dc6e6e67e..ad130d490 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -507,7 +507,7 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { // format incomplete utf-8 multibyte character for output static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == -1 ? "" : common_token_to_piece(ctx, token); + std::string out = token == LLAMA_TOKEN_NULL ? 
"" : common_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) diff --git a/include/llama.h b/include/llama.h index a0d5ba5dd..0f619aa19 100644 --- a/include/llama.h +++ b/include/llama.h @@ -34,7 +34,6 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -// TODO: use everywhere in the implementation #define LLAMA_TOKEN_NULL -1 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 405e0528f..22596499a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1923,24 +1923,24 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } - if (vocab.special_eom_id != -1) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, vocab.special_eom_id, vocab.id_to_token[vocab.special_eom_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } - if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } + if (vocab.special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } + if (vocab.special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, vocab.special_eom_id, vocab.id_to_token[vocab.special_eom_id].text.c_str() ); } + if (vocab.special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", 
__func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.special_cls_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } + if (vocab.special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } - if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); } - if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); } - if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); } - if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); } - if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); } - if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); } + if (vocab.special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); } + if (vocab.special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); } + if (vocab.special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); } + if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); } + if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); } + if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); } for (const auto & id : vocab.special_eog_ids) { LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() ); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 69cea2f14..ef5a576cc 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -257,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) for (int i = 0; i < 
(int)cur_p->size; ++i) { const float val = cur_p->data[i].logit; int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets-1, ib)); + ib = std::max(0, std::min(nbuckets - 1, ib)); bucket_idx[i] = ib; ++histo[ib]; } @@ -280,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) for (int i = 0; i < (int)cur_p->size; ++i) { int j = bucket_idx[i]; if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i]; + *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; } } ptr = tmp_tokens.data(); int ndone = 0; - for (int j = nbuckets-1; j > ib; --j) { + for (int j = nbuckets - 1; j > ib; --j) { std::sort(ptr, ptr + histo[j], comp); ptr += histo[j]; ndone += histo[j]; @@ -1832,7 +1832,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); if (n > 0) { lt = k; - rt = k+n-1; + rt = k + n - 1; } } else { // If k is inside the current Z-box, consider two cases. diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 3fcfcaa3f..a4c015484 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -497,7 +497,7 @@ struct llm_tokenizer_bpe_session { bool append_bos(std::vector & output) const { if (vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); + GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_bos_id); return true; } @@ -506,7 +506,7 @@ struct llm_tokenizer_bpe_session { bool append_eos(std::vector & output) const { if (vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); + GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_eos_id); return true; } @@ -1403,7 +1403,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< if (source == 0) { buffer.erase_after(buffer.before_begin()); } else { - buffer.erase_after(std::next(buffer.begin(), (source-1))); + buffer.erase_after(std::next(buffer.begin(), (source - 1))); } // repeat for the right side @@ -1417,7 +1417,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< if (source == 0) { buffer.erase_after(buffer.before_begin()); } else { - buffer.erase_after(std::next(buffer.begin(), (source-1))); + buffer.erase_after(std::next(buffer.begin(), (source - 1))); } break; } @@ -1454,7 +1454,7 @@ std::vector llama_tokenize_internal( bool is_prev_special = true; // prefix with space if first token if (add_special && vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); + GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_bos_id); is_prev_special = true; } @@ -1489,7 +1489,7 @@ std::vector llama_tokenize_internal( } if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); + GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_eos_id); } } break; @@ -1522,7 +1522,7 @@ std::vector llama_tokenize_internal( case LLAMA_VOCAB_TYPE_WPM: { if (add_special) { - GGML_ASSERT(vocab.special_cls_id != -1); + GGML_ASSERT(vocab.special_cls_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_cls_id); } @@ -1542,14 +1542,14 @@ std::vector llama_tokenize_internal( } if (add_special) { - GGML_ASSERT(vocab.special_sep_id != -1); + GGML_ASSERT(vocab.special_sep_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_sep_id); } } break; case LLAMA_VOCAB_TYPE_UGM: { if (add_special && 
vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); + GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_bos_id); } llm_tokenizer_ugm_session session(vocab); @@ -1574,7 +1574,7 @@ std::vector llama_tokenize_internal( } if (add_special && vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); + GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL); output.push_back(vocab.special_eos_id); } } break; @@ -1642,7 +1642,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla } bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) { - return token != -1 && vocab.special_eog_ids.count(token) > 0; + return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0; } bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) { @@ -1881,7 +1881,7 @@ int32_t llama_detokenize_impl( } if (remove_special && vocab.tokenizer_add_eos) { - if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) { + if (n_tokens > 0 && tokens[n_tokens - 1] == vocab.special_eos_id) { n_tokens--; } } From ae2f606bb598b287f5fb69c9fdfc98b86598c6cc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 10:52:38 +0200 Subject: [PATCH 12/25] mmap : fix fileno macro clash (#11076) * mmap : fix fileno macro clash ggml-ci * cont ggml-ci --- src/llama-mmap.cpp | 10 +++++++--- src/llama-mmap.h | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index a99326335..a8cb9439b 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -241,12 +241,16 @@ llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } -int llama_file::fileno() const { +int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); #else return ::fileno(pimpl->fp); #endif +#endif } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } @@ -265,7 +269,7 @@ struct llama_mmap::impl { impl(struct llama_file * file, size_t prefetch, bool numa) { size = file->size(); - int fd = file->fileno(); + int fd = file->file_id(); int flags = MAP_SHARED; if (numa) { prefetch = 0; } #ifdef __linux__ @@ -357,7 +361,7 @@ struct llama_mmap::impl { size = file->size(); - HANDLE hFile = (HANDLE) _get_osfhandle(file->fileno()); + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); diff --git a/src/llama-mmap.h b/src/llama-mmap.h index 6bcddee8c..1da9ecb6b 100644 --- a/src/llama-mmap.h +++ b/src/llama-mmap.h @@ -18,7 +18,7 @@ struct llama_file { size_t tell() const; size_t size() const; - int fileno() const; + int file_id() const; // fileno overload void seek(size_t offset, int whence) const; From 3e6e7a6bc2c4b980a0cf0fcb5cb3b79a965b5f14 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 10:54:25 +0200 Subject: [PATCH 13/25] tokenize : escape the prompt (#11058) * tokenize : escape the prompt * tokenize : update help --- examples/tokenize/tokenize.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index c97e22724..57d9d4312 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -31,6 +31,7 @@ static void print_usage_information(const char * argv0) { printf(" 
-p PROMPT, --prompt PROMPT read prompt from the argument.\n"); printf(" --stdin read prompt from standard input.\n"); printf(" --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); + printf(" --no-escape do not escape input (such as \\n, \\t, etc.).\n"); printf(" --no-parse-special do not parse control tokens.\n"); printf(" --log-disable disable logs. Makes stderr quiet when loading the model.\n"); printf(" --show-count print the total number of tokens.\n"); @@ -198,6 +199,7 @@ int main(int raw_argc, char ** raw_argv) { // variables where to put any arguments we see. bool printing_ids = false; bool no_bos = false; + bool no_escape = false; bool no_parse_special = false; bool disable_logging = false; bool show_token_count = false; @@ -233,6 +235,9 @@ int main(int raw_argc, char ** raw_argv) { else if (arg == "--no-bos") { no_bos = true; } + else if (arg == "--no-escape") { + no_escape = true; + } else if (arg == "--no-parse-special") { no_parse_special = true; } @@ -363,6 +368,11 @@ int main(int raw_argc, char ** raw_argv) { const bool model_wants_add_bos = llama_add_bos_token(model); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; + const bool escape = !no_escape; + + if (escape) { + string_process_escapes(prompt); + } std::vector tokens; tokens = common_tokenize(model, prompt, add_bos, parse_special); From 47182dd03fe04a4ffda5d7f4c8a109ae0056cf56 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 10:55:18 +0200 Subject: [PATCH 14/25] llama : update llama_model API names (#11063) * llama : deprecate llama_free_model, add llama_model_free ggml-ci * llama : change `llama_load_model_from_file` -> `llama_model_load_from_file` ggml-ci --- common/common.cpp | 14 +++++++------- examples/batched-bench/batched-bench.cpp | 4 ++-- examples/batched/batched.cpp | 4 ++-- examples/gritlm/gritlm.cpp | 4 ++-- examples/llama-bench/llama-bench.cpp | 8 ++++---- examples/llava/llava-cli.cpp | 6 +++--- examples/llava/minicpmv-cli.cpp | 4 ++-- examples/llava/qwen2vl-cli.cpp | 6 +++--- examples/passkey/passkey.cpp | 4 ++-- examples/quantize-stats/quantize-stats.cpp | 8 ++++---- examples/run/run.cpp | 2 +- examples/simple-chat/simple-chat.cpp | 4 ++-- examples/simple/simple.cpp | 4 ++-- examples/tokenize/tokenize.cpp | 4 ++-- include/llama-cpp.h | 2 +- include/llama.h | 13 ++++++++++--- src/llama-model.cpp | 4 ++++ src/llama.cpp | 16 +++++++++++----- tests/test-autorelease.cpp | 4 ++-- tests/test-model-load-cancel.cpp | 2 +- tests/test-tokenizer-0.cpp | 6 +++--- tests/test-tokenizer-1-bpe.cpp | 6 +++--- tests/test-tokenizer-1-spm.cpp | 6 +++--- 23 files changed, 76 insertions(+), 59 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d6a7ab753..4fd36105e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -846,7 +846,7 @@ struct common_init_result common_init_from_params(common_params & params) { } else if (!params.model_url.empty()) { model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); } if (model == NULL) { @@ -873,7 +873,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!ok) { - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -884,7 +884,7 @@ struct common_init_result common_init_from_params(common_params & params) { 
llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -900,7 +900,7 @@ struct common_init_result common_init_from_params(common_params & params) { const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -913,7 +913,7 @@ struct common_init_result common_init_from_params(common_params & params) { params.control_vector_layer_end); if (err) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -926,7 +926,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -1411,7 +1411,7 @@ struct llama_model * common_load_model_from_url( } } - return llama_load_model_from_file(local_path.c_str(), params); + return llama_model_load_from_file(local_path.c_str(), params); } struct llama_model * common_load_model_from_hf( diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index a3b21ad6b..dd75ff9f1 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -38,7 +38,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); @@ -194,7 +194,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 2e25b62f6..d34b03099 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -41,7 +41,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: error: unable to load model\n" , __func__); @@ -236,7 +236,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 18a945b33..4d2db5624 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -165,7 +165,7 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); @@ -219,7 +219,7 @@ int main(int argc, char * argv[]) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); return 0; diff --git 
a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2338ad106..2a0916766 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1526,10 +1526,10 @@ int main(int argc, char ** argv) { // keep the same model between tests when possible if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) { if (lmodel) { - llama_free_model(lmodel); + llama_model_free(lmodel); } - lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams()); + lmodel = llama_model_load_from_file(inst.model.c_str(), inst.to_llama_mparams()); if (lmodel == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str()); return 1; @@ -1540,7 +1540,7 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); - llama_free_model(lmodel); + llama_model_free(lmodel); return 1; } @@ -1626,7 +1626,7 @@ int main(int argc, char ** argv) { ggml_threadpool_free_fn(threadpool); } - llama_free_model(lmodel); + llama_model_free(lmodel); if (p) { p->print_footer(); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 2691c6e6b..27215a42e 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -221,7 +221,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -265,7 +265,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } @@ -323,7 +323,7 @@ int main(int argc, char ** argv) { } } - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index e9cbb51ed..2342bdd09 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -31,7 +31,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -75,7 +75,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index e86a60280..f3e5d66e2 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -310,7 +310,7 @@ static struct llama_model * llava_init(common_params * params) { llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { 
LOG_ERR("%s: unable to load model\n" , __func__); return NULL; @@ -354,7 +354,7 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } @@ -575,7 +575,7 @@ int main(int argc, char ** argv) { } } - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 09bba708f..ea91f376c 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -63,7 +63,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { LOG_ERR("%s: unable to load model\n" , __func__); @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index ab91d0b40..9bfbb8862 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -309,7 +309,7 @@ int main(int argc, char ** argv) { auto mparams = llama_model_default_params(); mparams.use_mlock = false; - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); @@ -323,7 +323,7 @@ int main(int argc, char ** argv) { if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } @@ -347,7 +347,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: Quantization should be tested with a float model, " "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 1; } included_layers++; @@ -409,7 +409,7 @@ int main(int argc, char ** argv) { llama_free(ctx); - llama_free_model(model); + llama_model_free(model); // report timing { const int64_t t_main_end_us = ggml_time_us(); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 75b817272..c52a7961f 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -664,7 +664,7 @@ class LlamaData { "\r%*s" "\rLoading model", get_terminal_width(), " "); - llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params)); + llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); if (!model) { printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); } diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 7f4da666b..d72f5bcdd 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -69,7 +69,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; - llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); + llama_model * model = 
llama_model_load_from_file(model_path.c_str(), model_params); if (!model) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; @@ -194,7 +194,7 @@ int main(int argc, char ** argv) { } llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 3288c0250..f69117890 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); model_params.n_gpu_layers = ngl; - llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index 57d9d4312..684ca054a 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -338,7 +338,7 @@ int main(int raw_argc, char ** raw_argv) { llama_model_params model_params = llama_model_default_params(); model_params.vocab_only = true; - llama_model * model = llama_load_model_from_file(model_path, model_params); + llama_model * model = llama_model_load_from_file(model_path, model_params); if (!model) { fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); return 1; @@ -408,7 +408,7 @@ int main(int raw_argc, char ** raw_argv) { } // silence valgrind llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 1500cb2fc..11306b17f 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -9,7 +9,7 @@ #include "llama.h" struct llama_model_deleter { - void operator()(llama_model * model) { llama_free_model(model); } + void operator()(llama_model * model) { llama_model_free(model); } }; struct llama_context_deleter { diff --git a/include/llama.h b/include/llama.h index 0f619aa19..0295a51fb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -413,12 +413,19 @@ extern "C" { // Call once at the end of the program - currently only used for MPI LLAMA_API void llama_backend_free(void); - LLAMA_API struct llama_model * llama_load_model_from_file( + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); + + LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - // TODO: rename to llama_model_free - LLAMA_API void llama_free_model(struct llama_model * model); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); // TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_new_context_with_model( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 22596499a..7deb3683b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2009,6 +2009,10 @@ struct llama_model_params llama_model_default_params() { } void llama_free_model(struct llama_model * model) { + llama_model_free(model); +} + +void 
llama_model_free(struct llama_model * model) { delete model; } diff --git a/src/llama.cpp b/src/llama.cpp index 4a6798f41..7337c34ce 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -11656,6 +11656,12 @@ int64_t llama_time_us(void) { struct llama_model * llama_load_model_from_file( const char * path_model, struct llama_model_params params) { + return llama_model_load_from_file(path_model, params); +} + +struct llama_model * llama_model_load_from_file( + const char * path_model, + struct llama_model_params params) { ggml_time_init(); llama_model * model = new llama_model; @@ -11694,7 +11700,7 @@ struct llama_model * llama_load_model_from_file( ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); if (!rpc_reg) { LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__); - llama_free_model(model); + llama_model_free(model); return nullptr; } @@ -11702,7 +11708,7 @@ struct llama_model * llama_load_model_from_file( ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); if (!ggml_backend_rpc_add_device_fn) { LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__); - llama_free_model(model); + llama_model_free(model); return nullptr; } @@ -11712,7 +11718,7 @@ struct llama_model * llama_load_model_from_file( model->devices.push_back(dev); } else { LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str()); - llama_free_model(model); + llama_model_free(model); return nullptr; } } @@ -11744,7 +11750,7 @@ struct llama_model * llama_load_model_from_file( if (params.split_mode == LLAMA_SPLIT_MODE_NONE) { if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) { LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size()); - llama_free_model(model); + llama_model_free(model); return nullptr; } ggml_backend_dev_t main_gpu = model->devices[params.main_gpu]; @@ -11767,7 +11773,7 @@ struct llama_model * llama_load_model_from_file( LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); } - llama_free_model(model); + llama_model_free(model); return nullptr; } diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp index 57fa00011..ba084a91a 100644 --- a/tests/test-autorelease.cpp +++ b/tests/test-autorelease.cpp @@ -13,10 +13,10 @@ int main(int argc, char ** argv) { std::thread([&model_path]() { llama_backend_init(); - auto * model = llama_load_model_from_file(model_path, llama_model_default_params()); + auto * model = llama_model_load_from_file(model_path, llama_model_default_params()); auto * ctx = llama_new_context_with_model(model, llama_context_default_params()); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); }).join(); diff --git a/tests/test-model-load-cancel.cpp b/tests/test-model-load-cancel.cpp index 858535c3c..9095826fa 100644 --- a/tests/test-model-load-cancel.cpp +++ b/tests/test-model-load-cancel.cpp @@ -21,7 +21,7 @@ int main(int argc, char *argv[] ) { (void) ctx; return progress > 0.50; }; - auto * model = llama_load_model_from_file(model_path, params); + auto * model = llama_model_load_from_file(model_path, params); llama_backend_free(); return model == nullptr ? 
EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 0af85f002..121c2c60c 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -152,7 +152,7 @@ int main(int argc, char **argv) { mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), mparams); + model = llama_model_load_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -165,7 +165,7 @@ int main(int argc, char **argv) { if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } @@ -300,7 +300,7 @@ int main(int argc, char **argv) { fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); } - llama_free_model(model); + llama_model_free(model); llama_free(ctx); llama_backend_free(); diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp index 0ff7fc833..5718fab04 100644 --- a/tests/test-tokenizer-1-bpe.cpp +++ b/tests/test-tokenizer-1-bpe.cpp @@ -46,7 +46,7 @@ int main(int argc, char **argv) { mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), mparams); + model = llama_model_load_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -59,7 +59,7 @@ int main(int argc, char **argv) { if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } @@ -143,7 +143,7 @@ int main(int argc, char **argv) { } } - llama_free_model(model); + llama_model_free(model); llama_free(ctx); llama_backend_free(); diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp index 9b0716a43..ac05387c9 100644 --- a/tests/test-tokenizer-1-spm.cpp +++ b/tests/test-tokenizer-1-spm.cpp @@ -34,7 +34,7 @@ int main(int argc, char ** argv) { mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), mparams); + model = llama_model_load_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -47,7 +47,7 @@ int main(int argc, char ** argv) { if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { } } - llama_free_model(model); + llama_model_free(model); llama_free(ctx); llama_backend_free(); From 6369f867a410416239d9f20ec27c2b1d6a9fee52 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 6 Jan 2025 10:28:17 +0100 Subject: [PATCH 15/25] llama : rename missed batch params/vars to ubatch (#10059) This commit renames the `batch` parameter to `ubatch` in the `llama_kv_cache_find_slot`, `llm_build_inp_embd`, and `llm_build_mamba` functions. The motivation for this is that this should have been done as part of Commit 19d900a7565b8f6b0a708836a57d26966cb9efe2 ("llama : rename batch to ubatch (#9950)") but for some reason I missed these functions in that commit and only noticed them now (sorry). 
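
For quick reference, a sketch of the renamed signatures (illustrative only, not additional code in this patch — every name below is taken from the hunks that follow, with unrelated parameters elided as "..."):

    // the KV-cache slot search now names the parameter after the type it
    // actually receives: the internal micro-batch (llama_ubatch), not a
    // caller-facing llama_batch
    struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache     & cache,
            const struct llama_ubatch & ubatch);   // was: ... & batch

    // the graph builders follow the same convention
    static struct ggml_tensor * llm_build_inp_embd(
            struct ggml_context  * ctx,
            struct llama_context & lctx,
            const llama_hparams  & hparams,
            const llama_ubatch   & ubatch,         // was: batch
            struct ggml_tensor   * tok_embd,
            const llm_build_cb   & cb);
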
--- src/llama-kv-cache.cpp | 32 ++++++++++++++++---------------- src/llama.cpp | 18 +++++++++--------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 53379253a..90b6c56ed 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -119,10 +119,10 @@ bool llama_kv_cache_init( struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { - const uint32_t n_tokens = batch.n_tokens; - const uint32_t n_seqs = batch.n_seqs; - const uint32_t n_seq_tokens = batch.n_seq_tokens; + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), @@ -130,16 +130,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // A slot should be always be contiguous. // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs); int32_t min = cache.size - 1; int32_t max = 0; // everything should fit if all seq_ids are smaller than the max for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = batch.n_seq_id[s]; + const uint32_t n_seq_id = ubatch.n_seq_id[s]; for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + const llama_seq_id seq_id = ubatch.seq_id[s][j]; if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { // too big seq_id @@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // find usable cell range for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; llama_kv_cell & seq_meta = cache.cells[seq_id]; bool has_cell = false; if (seq_meta.tail >= 0) { @@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // gather and re-order for (uint32_t s = 0; s < n_seqs; ++s) { int32_t dst_id = s + min; - int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & src_cell = cache.cells[src_id]; @@ -258,7 +258,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // update the pos of the used seqs for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; int32_t cell_id = s + min; llama_kv_cell & cell = cache.cells[cell_id]; @@ -266,12 +266,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. 
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); } cell.pos = last_pos; cell.seq_id.clear(); - for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; cell.seq_id.insert(seq_id); cache.cells[seq_id].tail = cell_id; } @@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = batch.pos[k]; + cache.cells[cache.head + k].pos = ubatch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); } } } diff --git a/src/llama.cpp b/src/llama.cpp index 7337c34ce..60728e5bb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2540,21 +2540,21 @@ static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, struct llama_context & lctx, const llama_hparams & hparams, - const llama_ubatch & batch, + const llama_ubatch & ubatch, struct ggml_tensor * tok_embd, const llm_build_cb & cb) { const int64_t n_embd = hparams.n_embd; struct ggml_tensor * inpL; - if (batch.token) { - lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); + if (ubatch.token) { + lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens); cb(lctx.inp_tokens, "inp_tokens", -1); ggml_set_input(lctx.inp_tokens); inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); } @@ -3149,7 +3149,7 @@ static struct ggml_tensor * llm_build_copy_mask_state( static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, - const llama_ubatch & batch, + const llama_ubatch & ubatch, struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, @@ -3165,17 +3165,17 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; - const int64_t n_seqs = batch.n_seqs; + const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. 
FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; // Use the same RMS norm as the final layer norm const float norm_rms_eps = hparams.f_norm_rms_eps; - const int64_t n_seq_tokens = batch.n_seq_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(batch.equal_seqs); - GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); struct ggml_tensor * conv_states_all = kv.k_l[il]; struct ggml_tensor * ssm_states_all = kv.v_l[il]; From 96a1dc27c3f09bf1ed83a26292d571795bcf27fa Mon Sep 17 00:00:00 2001 From: Asghar Ghorbani Date: Mon, 6 Jan 2025 12:21:46 +0100 Subject: [PATCH 16/25] llama : prevent system info string accumulation across calls (#11101) --- src/llama.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama.cpp b/src/llama.cpp index 60728e5bb..c162c31a6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12458,6 +12458,8 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int const char * llama_print_system_info(void) { static std::string s; + s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls. + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { auto * reg = ggml_backend_reg_get(i); From 09186fabbe05236f2b9446ba6c643cb737540d10 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 6 Jan 2025 13:41:12 +0100 Subject: [PATCH 17/25] llama : remove check flash_attn with lora (#11104) --- src/llama.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c162c31a6..ebd6e3b29 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -11519,13 +11519,7 @@ int32_t llama_lora_adapter_set( struct llama_context * ctx, struct llama_lora_adapter * adapter, float scale) { - if (ctx->cparams.flash_attn) { - LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__); - return -1; - } - ctx->lora_adapters[adapter] = scale; - return 0; } From e6e7c75d94adf4d39e846d30807c531ff22865e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 15:36:08 +0200 Subject: [PATCH 18/25] server : fix extra BOS in infill endpoint (#11106) * server : fix extra BOS in infill endpoing ggml-ci * server : update infill tests --- examples/server/server.cpp | 2 +- examples/server/tests/unit/test_infill.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c2e62ba69..127323e77 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3797,7 +3797,7 @@ int main(int argc, char ** argv) { data["input_extra"] = input_extra; // default to empty array if it's not exist std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); data["prompt"] = format_infill( ctx_server.ctx, diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py index ad4b8192a..10554db0f 100644 --- a/examples/server/tests/unit/test_infill.py +++ b/examples/server/tests/unit/test_infill.py @@ -18,7 +18,7 @@ def test_infill_without_input_extra(): "input_suffix": "}\n", }) assert res.status_code == 200 - assert 
match_regex("(Ann|small|shiny)+", res.body["content"]) + assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"]) def test_infill_with_input_extra(): From 96be8c32649378a23031630a48c440f3a5d0839b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 6 Jan 2025 16:34:49 +0100 Subject: [PATCH 19/25] github : add cmd line field to bug report (#11090) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * github : cmd line to bug report * codeowners : (@ngxson) only watch dockerfile * Apply suggestions from code review [no ci] Co-authored-by: Johannes Gäßler * rm cmd in log output [no ci] * rm 2 [no ci] * no need backticks [no ci] --------- Co-authored-by: Johannes Gäßler --- .github/ISSUE_TEMPLATE/010-bug-compilation.yml | 12 +++++++++++- .github/ISSUE_TEMPLATE/019-bug-misc.yml | 12 +++++++++++- CODEOWNERS | 2 +- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml index f10b3a2b2..b85bf5741 100644 --- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -65,12 +65,22 @@ body: If possible, please do a git bisect and identify the exact commit that introduced the bug. validations: required: false + - type: textarea + id: command + attributes: + label: Compile command + description: > + Please provide the exact command you used to compile llama.cpp. For example: `cmake -B ...`. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true - type: textarea id: logs attributes: label: Relevant log output description: > - Please copy and paste any relevant log output, including the command that you entered and any generated text. + Please copy and paste any relevant log output, including any generated text. This will be automatically formatted into code, so no need for backticks. render: shell validations: diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml index d157ea307..1904e31fd 100644 --- a/.github/ISSUE_TEMPLATE/019-bug-misc.yml +++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml @@ -52,6 +52,16 @@ body: - Other (Please specify in the next section) validations: required: false + - type: textarea + id: command + attributes: + label: Command line + description: > + Please provide the exact commands you entered, if applicable. For example: `llama-server -m ... -c ...`, `llama-cli -m ...`, etc. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false - type: textarea id: info attributes: @@ -74,7 +84,7 @@ body: attributes: label: Relevant log output description: > - If applicable, please copy and paste any relevant log output, including the command that you entered and any generated text. + If applicable, please copy and paste any relevant log output, including any generated text. This will be automatically formatted into code, so no need for backticks. 
render: shell validations: diff --git a/CODEOWNERS b/CODEOWNERS index adeba5395..c9fa34761 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,5 @@ # collaborators can optionally add themselves here to indicate their availability for reviewing related PRs /ci/ @ggerganov -/.devops/ @ngxson +/.devops/*.Dockerfile @ngxson /examples/server/ @ngxson From ecebbd292d741ac084cf248146b2cfb17002aa1d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 6 Jan 2025 17:52:35 +0200 Subject: [PATCH 20/25] llama : remove unused headers (#11109) ggml-ci --- src/llama.cpp | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index ebd6e3b29..8ea6686c9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8,7 +8,6 @@ #include "llama-kv-cache.h" #include "llama-model-loader.h" #include "llama-model.h" -#include "llama-quant.h" #include "ggml.h" #include "ggml-alloc.h" @@ -18,12 +17,8 @@ #include #include #include -#include #include -#include -#include #include -#include #include #include #include @@ -31,10 +26,7 @@ #include #include #include -#include #include -#include -#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -12434,16 +12426,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, return 0; } -int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { +int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) { std::string str_split_path(split_path); char postfix[32]; snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count); std::string str_postfix(postfix); - // check if dest ends with postfix + // check if split_prefix ends with postfix int size_prefix = str_split_path.size() - str_postfix.size(); if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { - snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); + snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); return size_prefix; } From dc7cef9f373f2a24b851f0df7a618c5209e593fa Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 6 Jan 2025 22:45:28 +0000 Subject: [PATCH 21/25] llama-run : fix context size (#11094) Set `n_ctx` equal to `n_batch` in `Opt` class. Now context size is a more reasonable 2048. Signed-off-by: Eric Curtin --- examples/run/run.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index c52a7961f..2888fcfed 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -83,6 +83,7 @@ class Opt { } ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; + ctx_params.n_ctx = ctx_params.n_batch; model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; temperature = temperature >= 0 ? temperature : temperature_default; From c0d6f790d07aa78be15584ec394ac20739ade93b Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Tue, 7 Jan 2025 11:56:07 +0530 Subject: [PATCH 22/25] SYCL: Use get_multi_ptr instead of deprecated get_pointer in wkv6 (#11087) * SYCL: Use get_multi_ptr instead of deprecated get_pointer in wkv6 * Revert "SYCL: Use get_multi_ptr instead of deprecated get_pointer in wkv6" This reverts commit f62dc45f318e48d375e7734b34cbddee81deed52. 
* Reland: Use get_multi_ptr instead of deprecated get_pointer in wkv6 --- ggml/src/ggml-sycl/wkv6.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp index 75ddfb86a..105db6f03 100644 --- a/ggml/src/ggml-sycl/wkv6.cpp +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -131,7 +131,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s [=](sycl::nd_item<3> item_ct1) { rwkv_wkv_f32_kernel( B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, - item_ct1, shared_mem_acc.get_pointer() + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() ); }); }); From a4dd490069a66ae56b42127048f06757fc4de4f7 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 7 Jan 2025 08:37:02 +0200 Subject: [PATCH 23/25] rpc : code cleanup (#11107) Remove duplicated macros, use GGML_LOG_ERROR for errors --- ggml/src/ggml-rpc/ggml-rpc.cpp | 49 ++++++++++++++-------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 2213aba9f..63da2b86b 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -27,15 +27,6 @@ #endif #include -#define UNUSED GGML_UNUSED - -#define GGML_DEBUG 0 -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - #ifdef _WIN32 typedef SOCKET sockfd_t; using ssize_t = __int64; @@ -411,7 +402,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { initialized = true; } #else - UNUSED(initialized); + GGML_UNUSED(initialized); #endif auto sock = socket_connect(host.c_str(), port); if (sock == nullptr) { @@ -640,7 +631,7 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) { } static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { - UNUSED(backend); + GGML_UNUSED(backend); // this is no-op because we don't have any async operations } @@ -850,7 +841,7 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); buffers.insert(buffer); } else { - GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); } } @@ -872,7 +863,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } void * base = ggml_backend_buffer_get_base(buffer); @@ -884,7 +875,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) { GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } ggml_backend_buffer_free(buffer); @@ -896,7 +887,7 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, 
request.value); ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } ggml_backend_buffer_clear(buffer, request.value); @@ -952,7 +943,7 @@ bool rpc_server::set_tensor(const std::vector & input) { struct ggml_context * ctx = ggml_init(params); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); ggml_free(ctx); return false; } @@ -1017,7 +1008,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< struct ggml_context * ctx = ggml_init(params); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); ggml_free(ctx); return false; } @@ -1051,7 +1042,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); ggml_free(ctx); return false; } @@ -1385,14 +1376,14 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); - UNUSED(dev); + GGML_UNUSED(dev); } static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { // TODO: obtain value from the server return GGML_BACKEND_DEVICE_TYPE_GPU; - UNUSED(dev); + GGML_UNUSED(dev); } static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { @@ -1413,7 +1404,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const return ggml_backend_rpc_init(ctx->endpoint.c_str()); - UNUSED(params); + GGML_UNUSED(params); } static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -1421,12 +1412,12 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); - UNUSED(dev); + GGML_UNUSED(dev); } static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - UNUSED(dev); - UNUSED(op); + GGML_UNUSED(dev); + GGML_UNUSED(op); //TODO: call the remote backend and cache the results return true; } @@ -1463,20 +1454,20 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { return "RPC"; - UNUSED(reg); + GGML_UNUSED(reg); } static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { return 0; - UNUSED(reg); + GGML_UNUSED(reg); } static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); - UNUSED(reg); - UNUSED(index); + GGML_UNUSED(reg); + GGML_UNUSED(index); } static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { @@ -1485,7 +1476,7 @@ static void * 
ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch } return NULL; - UNUSED(reg); + GGML_UNUSED(reg); } static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = { From a3d50bc022bedd6c7754c24749a1fef4d2d60c7c Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Tue, 7 Jan 2025 12:38:05 +0100 Subject: [PATCH 24/25] ggml-backend : only offload from host buffers (#11120) --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index e2d6c4056..d034f8b7f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -761,7 +761,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st } // skip ROPE since the rope freqs tensor is too small to choose a backend based on it // not an ideal solution - if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && ggml_backend_buffer_is_host(src->buffer)) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op if (src_backend_id == sched->n_backends - 1) { From 017cc5f446863316d05522a87f25ec48713a9492 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Tue, 7 Jan 2025 16:11:57 +0100 Subject: [PATCH 25/25] ggml-backend : only offload from host buffers (fix) (#11124) --- ggml/src/ggml-backend.cpp | 4 ++-- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d034f8b7f..dba7be33b 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -761,10 +761,10 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st } // skip ROPE since the rope freqs tensor is too small to choose a backend based on it // not an ideal solution - if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && ggml_backend_buffer_is_host(src->buffer)) { + if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1) { + if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 622c63f1f..b311a5b1c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -4169,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g buffer->buft = buft; buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor; buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; return buffer; }
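Taken together, the last two patches narrow the scheduler's weight-offload heuristic: a weights tensor is only offered to a higher-priority backend when it currently resides in a host (CPU-visible) buffer, and the aarch64 CPU buffer now explicitly advertises that it does not support direct get/copy reads. A condensed sketch of the resulting check follows; the helper name is hypothetical, since in ggml-backend.cpp this logic stays inline in ggml_backend_sched_backend_id_from_cur:

    // Condensed, illustrative sketch of the offload gate after patches 24 and 25.
    // The helper name is hypothetical; the real logic is inline in ggml-backend.cpp.
    #include "ggml-backend-impl.h"   // struct ggml_backend_buffer (for ->usage)

    static bool weight_may_be_offloaded(const struct ggml_tensor * tensor,
                                        const struct ggml_tensor * src,
                                        int src_backend_id,
                                        int n_backends) {
        return tensor->op != GGML_OP_ROPE                               // rope freqs tensor is too small to decide on
            && src->buffer != NULL
            && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS  // only weight buffers are considered
            && src_backend_id == n_backends - 1                         // weight currently assigned to the lowest-priority backend
            && ggml_backend_buffer_is_host(src->buffer);                // and its buffer is host-accessible
    }

When this returns true, the scheduler still has to find a higher-priority backend that both supports and wants to offload the op, exactly as in the unchanged loop over sched->backends shown above.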