Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/full-cuda.Dockerfile
#	.devops/full-rocm.Dockerfile
#	.devops/llama-cli-cuda.Dockerfile
#	.devops/llama-cli-rocm.Dockerfile
#	.devops/llama-cli-vulkan.Dockerfile
#	.devops/llama-cpp-cuda.srpm.spec
#	.devops/llama-server-cuda.Dockerfile
#	.devops/llama-server-rocm.Dockerfile
#	.devops/llama-server-vulkan.Dockerfile
#	.github/workflows/build.yml
#	.github/workflows/docker.yml
#	CMakeLists.txt
#	Makefile
#	README.md
#	examples/llama.android/llama/src/main/cpp/CMakeLists.txt
#	flake.lock
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	grammars/README.md
#	scripts/sync-ggml-am.sh
#	scripts/sync-ggml.last
#	tests/test-chat-template.cpp
#	tests/test-grammar-integration.cpp
#	tests/test-json-schema-to-grammar.cpp
This commit is contained in:
Concedo 2024-06-30 10:59:42 +08:00
commit 02f92f6ecc
22 changed files with 632 additions and 182 deletions

View file

@ -241,6 +241,7 @@ enum llm_arch {
LLM_ARCH_INTERNLM2,
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@ -281,6 +282,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_INTERNLM2, "internlm2" },
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@ -502,10 +504,12 @@ enum llm_tensor {
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_NORM_2,
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_POST_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
@ -1028,6 +1032,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GEMMA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -2066,6 +2088,9 @@ enum e_model {
MODEL_8x22B,
MODEL_16x12B,
MODEL_10B_128x3_66B,
MODEL_57B_A14B,
MODEL_9B,
MODEL_27B,
};
static const size_t kiB = 1024;
@ -2242,6 +2267,7 @@ struct llama_layer {
struct ggml_tensor * attn_q_a_norm;
struct ggml_tensor * attn_kv_a_norm;
struct ggml_tensor * attn_sub_norm;
struct ggml_tensor * attn_post_norm;
struct ggml_tensor * ffn_sub_norm;
// attention
@ -2265,6 +2291,7 @@ struct llama_layer {
// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;
struct ggml_tensor * ffn_post_norm;
struct ggml_tensor * layer_out_norm;
struct ggml_tensor * layer_out_norm_b;
struct ggml_tensor * ffn_norm_exps;
@ -4320,6 +4347,9 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_8x22B: return "8x22B";
case MODEL_16x12B: return "16x12B";
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
case MODEL_57B_A14B: return "57B.A14B";
case MODEL_9B: return "9B";
case MODEL_27B: return "27B";
default: return "?B";
}
}
@ -4641,6 +4671,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_A2_7B; break;
case 28: model.type = e_model::MODEL_57B_A14B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
@ -4721,6 +4752,16 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 42: model.type = e_model::MODEL_9B; break;
case 46: model.type = e_model::MODEL_27B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -5130,6 +5171,9 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "poro-chat") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
} else if (
tokenizer_pre == "viking") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
@ -5224,10 +5268,10 @@ static void llm_load_vocab(
if (gen_name.find("code") != std::string::npos) {
if (model.arch == LLM_ARCH_LLAMA
&& 32010 < vocab.id_to_token.size()
&& vocab.id_to_token[32007].text == "<PRE>"
&& vocab.id_to_token[32008].text == "<SUF>"
&& vocab.id_to_token[32009].text == "<MID>"
&& vocab.id_to_token[32010].text == "<EOT>") {
&& vocab.id_to_token[32007].text.find("<PRE>") != std::string::npos
&& vocab.id_to_token[32008].text.find("<SUF>") != std::string::npos
&& vocab.id_to_token[32009].text.find("<MID>") != std::string::npos
&& vocab.id_to_token[32010].text.find("<EOT>") != std::string::npos) {
vocab.special_prefix_id = 32007;
vocab.special_suffix_id = 32008;
vocab.special_middle_id = 32009;
@ -6585,6 +6629,40 @@ static bool llm_load_tensors(
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
}
} break;
case LLM_ARCH_GEMMA2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
const int64_t n_ff = hparams.n_ff;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
for (uint32_t i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
}
} break;
case LLM_ARCH_STARCODER2:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -10996,6 +11074,125 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_gemma2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
cb(Qcur, "Qcur_scaled", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
cur = llm_build_norm(ctx0, sa_out, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post_norm", -1);
cur = ggml_add(ctx0, cur, sa_out);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -12376,6 +12573,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gemma();
} break;
case LLM_ARCH_GEMMA2:
{
result = llm.build_gemma2();
} break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();
@ -14003,6 +14204,12 @@ struct llm_tokenizer_bpe {
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_VIKING:
regex_exprs = {
"\\p{N}",
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@ -17915,6 +18122,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_GPTNEOX:
return LLAMA_ROPE_TYPE_NEOX;
@ -19752,7 +19960,10 @@ static int32_t llama_chat_apply_template_internal(
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
return tmpl.find(haystack) != std::string::npos;
};
if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
// chatml template
for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@ -19760,16 +19971,16 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
bool space_around_response = tmpl_contains("' ' + eos_token");
// [variant] add BOS inside history
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
// [variant] trim spaces from the input message
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
bool strip_message = tmpl_contains("content.strip()");
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
@ -19795,7 +20006,7 @@ static int32_t llama_chat_apply_template_internal(
}
}
// llama2 templates seem to not care about "add_generation_prompt"
} else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
} else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
// Phi 3
for (auto message : chat) {
std::string role(message->role);
@ -19804,7 +20015,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
} else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
// zephyr template
for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@ -19812,7 +20023,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
} else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
for (auto message : chat) {
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
@ -19821,7 +20032,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<s>assistant\n";
}
} else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
} else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("<start_of_turn>")) {
// google/gemma-7b-it
std::string system_prompt = "";
for (auto message : chat) {
@ -19843,7 +20054,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<start_of_turn>model\n";
}
} else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
} else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
// OrionStarAI/Orion-14B-Chat
std::string system_prompt = "";
for (auto message : chat) {
@ -19863,7 +20074,7 @@ static int32_t llama_chat_apply_template_internal(
ss << message->content << "</s>";
}
}
} else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
} else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
// openchat/openchat-3.5-0106,
for (auto message : chat) {
std::string role(message->role);
@ -19877,13 +20088,13 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "GPT4 Correct Assistant:";
}
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
} else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
// eachadea/vicuna-13b-1.1 (and Orca variant)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
// Orca-Vicuna variant uses a system prefix
if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
ss << "SYSTEM: " << message->content << "\n";
} else {
ss << message->content << "\n\n";
@ -19897,7 +20108,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "ASSISTANT:";
}
} else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
} else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
// deepseek-ai/deepseek-coder-33b-instruct
for (auto message : chat) {
std::string role(message->role);
@ -19912,7 +20123,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "### Response:\n";
}
} else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
} else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
// CohereForAI/c4ai-command-r-plus
for (auto message : chat) {
std::string role(message->role);
@ -19927,7 +20138,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
}
} else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
} else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
// Llama 3
for (auto message : chat) {
std::string role(message->role);
@ -19936,6 +20147,33 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << u8"<用户>";
ss << trim(message->content);
ss << "<AI>";
} else {
ss << trim(message->content);
}
}
} else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
// DeepSeek-V2
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << message->content << "\n\n";
} else if (role == "user") {
ss << "User: " << message->content << "\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << u8"<end▁of▁sentence>";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;