From 3f96aeff394e9b72bbd2fa665c3e023a70ed8648 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Fri, 9 May 2025 11:17:51 +0200
Subject: [PATCH 1/5] llama : one-off chat template fix for Mistral-Small-2503
 (#13398)

* llama : one-off chat template fix for Mistral-Small-2503

* update readme

* add mistral-v7-tekken
---
 src/llama-chat.cpp   | 14 ++++++++------
 src/llama-chat.h     |  1 +
 src/llama-model.cpp  |  8 ++++++++
 tools/mtmd/README.md |  2 +-
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index 46d43c58e..d12743e6b 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3 },
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7 },
+    { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
     { "phi3",              LLM_CHAT_TEMPLATE_PHI_3 },
     { "phi4",              LLM_CHAT_TEMPLATE_PHI_4 },
     { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
         // Official mistral 'v7' template
         // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+        //      https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
+        const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
         for (auto message : chat) {
             std::string role(message->role);
             std::string content(message->content);
             if (role == "system") {
-                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+                ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
             } else if (role == "user") {
-                ss << "[INST] " << content << "[/INST]";
-            }
-            else {
-                ss << " " << content << "</s>";
+                ss << "[INST]" << trailing_space << content << "[/INST]";
+            } else {
+                ss << trailing_space << content << "</s>";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 3f5843466..db24ade21 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -14,6 +14,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MISTRAL_V3,
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
+    LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
     LLM_CHAT_TEMPLATE_PHI_3,
     LLM_CHAT_TEMPLATE_PHI_4,
     LLM_CHAT_TEMPLATE_FALCON_3,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 3ca265be8..e8b78c1d0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13387,6 +13387,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
+        // one-off fix for very popular models (so we are not flooded with issues)
+        // do not extend this list unless absolutely necessary
+        // Mistral-Small-2503 does not have a built-in chat template
+        llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
+        if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
+            return "mistral-v7-tekken";
+        }
+
         return nullptr;
     }
diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md
index b97b9e8c5..20e7696ce 100644
--- a/tools/mtmd/README.md
+++ b/tools/mtmd/README.md
@@ -46,7 +46,7 @@ llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
 llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF

 # Mistral Small 3.1 24B (IQ2_M quantization)
-llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
+llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF
 ```

 ## How it works and what is `mmproj`?

From 2189fd3b6327a1d17893694125da8edcf74a6468 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen
Date: Fri, 9 May 2025 11:18:02 +0200
Subject: [PATCH 2/5] mtmd : fix batch_view for m-rope (#13397)

* mtmd : fix batch_view for m-rope

* nits : fix comment
---
 tools/mtmd/mtmd.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 5d18e8929..2fecf08a4 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -554,14 +554,19 @@ struct decode_embd_batch {
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
-        pos_view.resize(n_tokens * n_pos_per_embd);
+        pos_view.reserve(n_tokens * n_pos_per_embd);
         if (n_pos_per_embd > 1) {
             // mrope
             // for example, with layout of src: 1234...1234...1234...1234...
             // offset 2 will give us dst: 34...34...34...34...
             for (int i = 0; i < n_pos_per_embd; i++) {
-                auto src = pos.begin() + i * batch.n_tokens + offset;
-                pos_view.insert(pos_view.end(), src, src + n_tokens);
+                // assume n_tokens is less than or equal to batch.n_tokens
+                // batch.n_tokens is the number of **total** tokens
+                // n_tokens is the number of viewed tokens
+                size_t src_idx = i * batch.n_tokens + offset;
+                pos_view.insert(pos_view.end(),
+                    pos.data() + src_idx,
+                    pos.data() + src_idx + n_tokens);
             }
             pos_ptr = pos_view.data();
         } else {

From 0527771dd80bd18479dfaaa0a98be297fc3592bf Mon Sep 17 00:00:00 2001
From: R0CKSTAR
Date: Fri, 9 May 2025 17:25:50 +0800
Subject: [PATCH 3/5] llama-run: add support for downloading models from
 ModelScope (#13370)

Signed-off-by: Xiaodong Ye
---
 tools/run/README.md |  2 ++
 tools/run/run.cpp   | 22 ++++++++++++++++++----
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/tools/run/README.md b/tools/run/README.md
index 89a552079..5fd769b44 100644
--- a/tools/run/README.md
+++ b/tools/run/README.md
@@ -42,6 +42,8 @@ Examples:
   llama-run ollama://smollm:135m
   llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
   llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
+  llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
+  llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
   llama-run https://example.com/some-file1.gguf
   llama-run some-file2.gguf
   llama-run file://some-file3.gguf
diff --git a/tools/run/run.cpp b/tools/run/run.cpp
index e63c2aac3..a189ae7fa 100644
--- a/tools/run/run.cpp
+++ b/tools/run/run.cpp
@@ -267,7 +267,7 @@ class Opt {
             "Commands:\n"
             "  model\n"
             "    Model is a string with an optional prefix of \n"
-            "    huggingface:// (hf://), ollama://, https:// or file://.\n"
+            "    huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n"
             "    If no protocol is specified and a file exists in the specified\n"
             "    path, file:// is assumed, otherwise if a file does not exist in\n"
             "    the specified path, ollama:// is assumed. Models that are being\n"
@@ -282,6 +282,9 @@ class Opt {
             "    llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
             "    llama-run "
             "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
+            "    llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
+            "    llama-run "
+            "modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
             "    llama-run https://example.com/some-file1.gguf\n"
             "    llama-run some-file2.gguf\n"
             "    llama-run file://some-file3.gguf\n"
@@ -689,7 +692,7 @@ class LlamaData {
         return 0;
     }

-    int huggingface_dl(std::string & model, const std::string & bn) {
+    int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
@@ -697,8 +700,6 @@ class LlamaData {
         std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
         std::string url;

-        std::string model_endpoint = get_model_endpoint();
-
         if (pos == std::string::npos) {
             auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/");
             hfr = model_name;
@@ -720,6 +721,16 @@ class LlamaData {
         return download(url, bn, true, headers);
     }

+    int modelscope_dl(std::string & model, const std::string & bn) {
+        std::string model_endpoint = "https://modelscope.cn/models/";
+        return dl_from_endpoint(model_endpoint, model, bn);
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
+        std::string model_endpoint = get_model_endpoint();
+        return dl_from_endpoint(model_endpoint, model, bn);
+    }
+
     int ollama_dl(std::string & model, const std::string & bn) {
         const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
@@ -837,6 +848,9 @@ class LlamaData {
             rm_until_substring(model_, "hf.co/");
             rm_until_substring(model_, "://");
             ret = huggingface_dl(model_, bn);
+        } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) {
+            rm_until_substring(model_, "://");
+            ret = modelscope_dl(model_, bn);
         } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
                    !string_starts_with(model_, "https://ollama.com/library/")) {
             ret = download(model_, bn, true);

From efb8b47eda78ea8ae570d4fece3953aae499289e Mon Sep 17 00:00:00 2001
From: Bartowski <3266127+bartowski1182@users.noreply.github.com>
Date: Fri, 9 May 2025 05:53:58 -0400
Subject: [PATCH 4/5] imatrix : Add --parse-special for enabling parsing of
 special tokens in imatrix calculation (#13389)

* Add --parse-special for enabling parsing of special tokens in imatrix calculation

* whitespace
---
 common/arg.cpp            | 7 +++++++
 common/common.h           | 1 +
 tools/imatrix/imatrix.cpp | 5 +++--
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 9f87e9910..73a3cfe53 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2627,6 +2627,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.i_chunk = value;
         }
     ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    add_opt(common_arg(
+        {"--parse-special"},
+        string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
+        [](common_params & params) {
+            params.parse_special = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(common_arg(
         {"-pps"},
         string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"),
diff --git a/common/common.h b/common/common.h
index 907022454..d051d4ec9 100644
--- a/common/common.h
+++ b/common/common.h
@@ -409,6 +409,7 @@ struct common_params {
     bool process_output = false; // collect data for the output tensor
     bool compute_ppl    = true;  // whether to compute perplexity
+    bool parse_special  = false; // whether to parse special tokens during imatrix tokenization

     // cvector-generator params
     int n_pca_batch = 100;
diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp
index 2c39278db..81d0404d6 100644
--- a/tools/imatrix/imatrix.cpp
+++ b/tools/imatrix/imatrix.cpp
@@ -24,7 +24,8 @@ static void print_usage(int, char ** argv) {
     LOG("\n    %s \\\n"
         "       -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n"
         "       [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
-        "       [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]);
+        "       [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] \\\n"
+        "       [--parse-special]\n" , argv[0]);
     LOG("\n");
 }
@@ -439,7 +440,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
     auto tim1 = std::chrono::high_resolution_clock::now();
     LOG_INF("%s: tokenizing the input ..\n", __func__);

-    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true, params.parse_special);

     auto tim2 = std::chrono::high_resolution_clock::now();
     LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

From 5c86c9ed3ef1cc7307fdce05f0f0e2e45253cf90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 9 May 2025 12:14:04 +0200
Subject: [PATCH 5/5] CUDA: fix crash on large batch size for MoE models
 (#13384)

---
 ggml/src/ggml-cuda/getrows.cu | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index ea8bf6916..963e4d03d 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

-    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
+    const int i10 = blockIdx.x;
+    const int i11 = blockIdx.z / ne12;
+    const int i12 = blockIdx.z % ne12;

     if (i00 >= ne00) {
         return;
     }
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

-    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
+    const int i10 = blockIdx.x;
+    const int i11 = blockIdx.z / ne12;
+    const int i12 = blockIdx.z % ne12;

     if (i00 >= ne00) {
         return;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);

     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);

     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
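Editor's note (not part of the patches): the comment added in PATCH 5, "the x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher", refers to the CUDA launch limits of 2^31 - 1 blocks along gridDim.x versus 65535 along gridDim.y and gridDim.z. The standalone sketch below only mirrors the shape of the block_nums computation with hypothetical ne* sizes; it is an illustration of why moving the row count ne10 (which grows with the batch size, the case reported for MoE models) into the x dimension avoids the launch failure.

```cpp
// Plain C++ sketch; the ne* values are hypothetical, only the CUDA grid limits are real.
#include <cstdio>

int main() {
    const int CUDA_GET_ROWS_BLOCK_SIZE = 256;
    const int ne00 = 4096;    // row width in elements
    const int ne10 = 100000;  // number of rows to gather; scales with batch size
    const int ne11 = 1, ne12 = 1;

    const int block_num = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;

    // before the fix: ne10 was placed in gridDim.y, which is capped at 65535
    printf("old grid: x=%d y=%d z=%d (y over 65535: %s)\n",
           block_num, ne10, ne11*ne12, ne10 > 65535 ? "yes" : "no");

    // after the fix: ne10 is placed in gridDim.x, which allows up to 2^31 - 1 blocks
    printf("new grid: x=%d y=%d z=%d\n", ne10, block_num, ne11*ne12);
    return 0;
}
```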
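A similar note for PATCH 2: the m-rope comment "layout of src: 1234...1234...1234...1234..., offset 2 will give us dst: 34...34...34...34..." describes taking the same [offset, offset + n_tokens) slice out of each of the n_pos_per_embd position streams that are stored back to back. A minimal, self-contained illustration with hypothetical sizes (this is not the mtmd code itself):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // hypothetical layout: 4 total tokens, 3 position streams stored back to back -> 1234 1234 1234
    const int n_total = 4, n_pos_per_embd = 3;
    std::vector<int> pos;
    for (int i = 0; i < n_pos_per_embd; i++) {
        for (int t = 1; t <= n_total; t++) {
            pos.push_back(t);
        }
    }

    // view with offset = 2 and n_tokens = 2: take "34" from every stream
    const int offset = 2, n_tokens = 2;
    std::vector<int> pos_view;
    pos_view.reserve(n_tokens * n_pos_per_embd);
    for (int i = 0; i < n_pos_per_embd; i++) {
        const size_t src_idx = (size_t) i * n_total + offset;
        pos_view.insert(pos_view.end(), pos.data() + src_idx, pos.data() + src_idx + n_tokens);
    }

    for (int v : pos_view) {
        printf("%d ", v);  // prints: 3 4 3 4 3 4
    }
    printf("\n");
    return 0;
}
```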