diff --git a/common/arg.cpp b/common/arg.cpp
index d41c66611..36b19538e 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2628,6 +2628,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.i_chunk = value;
         }
     ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    add_opt(common_arg(
+        {"--parse-special"},
+        string_format("parse special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
+        [](common_params & params) {
+            params.parse_special = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(common_arg(
         {"-pps"},
         string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"),
diff --git a/common/common.h b/common/common.h
index dfaaa6026..3b7e74bbd 100644
--- a/common/common.h
+++ b/common/common.h
@@ -405,6 +405,7 @@ struct common_params {

     bool process_output = false; // collect data for the output tensor
     bool compute_ppl = true; // whether to compute perplexity
+    bool parse_special = false; // whether to parse special tokens during imatrix tokenization

     // cvector-generator params
     int n_pca_batch = 100;
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index ea8bf6916..963e4d03d 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

-    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
+    const int i10 = blockIdx.x;
+    const int i11 = blockIdx.z / ne12;
+    const int i12 = blockIdx.z % ne12;

     if (i00 >= ne00) {
         return;
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {

-    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
+    const int i10 = blockIdx.x;
+    const int i11 = blockIdx.z / ne12;
+    const int i12 = blockIdx.z % ne12;

     if (i00 >= ne00) {
         return;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);

     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);

     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index 46d43c58e..d12743e6b 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
+    { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
     { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
     { "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
     { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
         // Official mistral 'v7' template
         // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+        //      https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
+        const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
         for (auto message : chat) {
             std::string role(message->role);
             std::string content(message->content);
             if (role == "system") {
-                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+                ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
             } else if (role == "user") {
-                ss << "[INST] " << content << "[/INST]";
-            }
-            else {
-                ss << " " << content << "</s>";
+                ss << "[INST]" << trailing_space << content << "[/INST]";
+            } else {
+                ss << trailing_space << content << "</s>";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 3f5843466..db24ade21 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -14,6 +14,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MISTRAL_V3,
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
+    LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
     LLM_CHAT_TEMPLATE_PHI_3,
     LLM_CHAT_TEMPLATE_PHI_4,
     LLM_CHAT_TEMPLATE_FALCON_3,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 2a5d2abd2..29484cec0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13490,6 +13490,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
+        // one-off fix for very popular models (so we are not flooded with issues)
+        // do not extend this list unless absolutely necessary
+        // Mistral-Small-2503 does not have built-in chat template
+        llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
+        if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
+            return "mistral-v7-tekken";
+        }
+
         return nullptr;
     }

diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index e3828d62a..4801a94aa 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -554,14 +554,19 @@ struct decode_embd_batch {
     llama_batch get_view(int offset, int n_tokens) {
         llama_pos * pos_ptr;
         pos_view.clear();
-        pos_view.resize(n_tokens * n_pos_per_embd);
+        pos_view.reserve(n_tokens * n_pos_per_embd);
         if (n_pos_per_embd > 1) {
             // mrope
             // for example, with layout of src: 1234...1234...1234...1234...
             // offset 2 will give us dst: 34...34...34...34...
             for (int i = 0; i < n_pos_per_embd; i++) {
-                auto src = pos.begin() + i * batch.n_tokens + offset;
-                pos_view.insert(pos_view.end(), src, src + n_tokens);
+                // assume n_tokens is less than or equal to batch.n_tokens
+                // batch.n_tokens is the number of **total** tokens
+                // n_tokens is the number of viewed tokens
+                size_t src_idx = i * batch.n_tokens + offset;
+                pos_view.insert(pos_view.end(),
+                    pos.data() + src_idx,
+                    pos.data() + src_idx + n_tokens);
             }
             pos_ptr = pos_view.data();
         } else {