Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	docs/development/HOWTO-add-model.md
#	ggml/src/ggml-sycl/rope.cpp
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2025-07-09 19:25:28 +08:00
commit b8c1fc7c9e
30 changed files with 1784 additions and 263 deletions

View file

@ -2736,6 +2736,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.public_path = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
add_opt(common_arg(
{"--api-prefix"}, "PREFIX",
string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
[](common_params & params, const std::string & value) {
params.api_prefix = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
add_opt(common_arg(
{"--no-webui"},
string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),

View file

@ -366,6 +366,7 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string api_prefix = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;

View file

@ -815,6 +815,24 @@ class TextModel(ModelBase):
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
# ref: https://huggingface.co/skt/A.X-4.0
res = "a.x-4.0"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if res is None:
logger.warning("\n")
@ -4896,15 +4914,17 @@ class Mamba2Model(TextModel):
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
# skip the assertion for FalconH1 Model
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
@ -4943,7 +4963,7 @@ class Mamba2Model(TextModel):
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))
@ -6535,6 +6555,277 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1
def __init__(self, *args, **kwargs):
# Set the hparam prefixes for Falcon Mamba2
self.hparam_prefixes = ["mamba"]
# Initialize the base Mamba2Model
super().__init__(*args, **kwargs)
# Use Llama conversion for attention
self._transformer_model_class = LlamaModel
# n_group and d_inner are used during reshape_tensors for mamaba2
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["mamba_d_ssm"])
self.d_head = self.find_hparam(["d_head"])
# Initialize any Falcon Mamba2 specific attributes
self.has_attention = True # Falcon Mamba2 has attention components
# Load Falcon-H1 multipliers from hyperparameters
self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
self.intermediate_size = self.find_hparam(["intermediate_size"])
self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)
def set_vocab(self):
self._set_vocab_gpt2()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
tensors = list(super().modify_tensors(data_torch, name, bid))
tensor = tensors[0][1]
if "down_proj" in name:
tensor = tensor * self.mlp_multipliers[1]
elif "gate_proj" in name:
tensor = tensor * self.mlp_multipliers[0]
elif "k_proj" in name:
tensor = tensor * self.key_multiplier * self.attention_in_multiplier
elif "q_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "v_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "o_proj" in name:
tensor = tensor * self.attention_out_multiplier
elif "out_proj" in name:
tensor = tensor * self.ssm_out_multiplier
elif "in_proj" in name:
tensor = tensor * self.ssm_in_multiplier
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
elif "lm_head" in name:
tensor = tensor * self.hparams["lm_head_multiplier"]
elif "embed_tokens" in name:
tensor = tensor * self.hparams["embedding_multiplier"]
elif "mamba.norm" in name:
tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
tensors = [(tensors[0][0], tensor)]
return tensors
def set_gguf_parameters(self):
super().set_gguf_parameters()
## General Params ##
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
# Override some Mamba2 defaults
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
## Attention params ##
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # Override value 0 from Mamba2
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_key_length(self.hparams["head_dim"])
self.gguf_writer.add_value_length(self.hparams["head_dim"])
## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# For handling tied embeddings
self._tok_embd = None
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
# 1. Get the pre-tokenizer identifier hash
tokpre = self.get_vocab_base_pre(tokenizer)
# 2. Reverse-engineer the merges list from mergeable_ranks
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2: # todo this is an assert in Qwen, why?
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
# 3. Generate the tokens and toktypes lists
vocab_size = self.hparams["vocab_size"]
assert tokenizer.vocab_size == vocab_size
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)
# 4. Write all vocab-related fields to the GGUF writer
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)
# 5. Add special tokens and chat templates
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# FIX for BOS token: Overwrite incorrect id read from config.json
self.gguf_writer.add_bos_token_id(127959) # <|bos|>
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
moe_intermediate_size = hparams["moe_intermediate_size"]
assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
moe_topk = hparams["moe_topk"]
assert all(topk == moe_topk[0] for topk in moe_topk)
self.gguf_writer.add_expert_used_count(moe_topk[0])
moe_shared_expert = hparams["num_shared_expert"]
assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
# Rope
rope_scaling = hparams.get("rope_scaling", {})
if rope_scaling.get("type") == "dynamic":
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
# There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
self.gguf_writer.add_context_length(256 * 1024) # 256k context length
# if any of our assumptions about the values are wrong, something has changed and this may need to be updated
assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
"HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "model.embed_tokens.weight":
self._tok_embd = data_torch.clone()
if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []
if name.find("mlp.experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
def set_vocab(self):
super().set_vocab()
# remove unsupported array slicing in chat template
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
if tokenizer.chat_template is not None:
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
###### CONVERSION LOGIC ######

View file

@ -128,6 +128,7 @@ models = [
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
{"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@ -137,6 +138,12 @@ pre_computed_hashes = [
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
# falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
]

View file

@ -501,7 +501,7 @@ extern "C" {
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_POOL_2D_BACK,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_UPSCALE,
GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D,
GGML_OP_ROLL,

View file

@ -190,7 +190,10 @@ static const char * cu_get_error_str(CUresult err) {
} \
} while (0)
#else
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
GGML_UNUSED(nbytes); \
} while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)

View file

@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

View file

@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

View file

@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_BF16:
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@ -210,6 +214,10 @@ void get_rows_cuda(
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_F16:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

View file

@ -3205,6 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -3378,7 +3380,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_PAD:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:

View file

@ -50,21 +50,19 @@ static __global__ void rope_norm(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
if (i0 >= n_dims) {
dst[idst + 0] = x[ix + 0];
dst[idst + 1] = x[ix + 1];
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -94,21 +92,19 @@ static __global__ void rope_neox(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -138,21 +134,19 @@ static __global__ void rope_multi(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row_dst*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
return;
}
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;

View file

@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
}
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
int y0_src = (int)floorf(y_src_f);
int y1_src = y0_src + 1;
y0_src = max(0, min(y0_src, ne01_src - 1));
y1_src = max(0, min(y1_src, ne01_src - 1));
float dy = y_src_f - (float)y0_src;
dy = max(0.0f, min(dy, 1.0f));
float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
int x0_src = (int)floorf(x_src_f);
int x1_src = x0_src + 1;
x0_src = max(0, min(x0_src, ne00_src - 1));
x1_src = max(0, min(x1_src, ne00_src - 1));
float dx = x_src_f - (float)x0_src;
dx = max(0.0f, min(dx, 1.0f));
const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
const float val_a = *p_a;
const float val_b = *p_b;
const float val_c = *p_c;
const float val_d = *p_d;
float result = val_a * (1.0f - dx) * (1.0f - dy) +
val_b * dx * (1.0f - dy) +
val_c * (1.0f - dx) * dy +
val_d * dx * dy;
dst[index] = result;
}
static void upscale_f32_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int ne13,
const float sf0, const float sf1, const float sf2, const float sf3,
cudaStream_t stream) {
int dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
}
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float sf0 = (float)dst->ne[0]/src0->ne[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1];
const float sf2 = (float)dst->ne[2]/src0->ne[2];
const int mode_flags = dst->op_params[0];
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
float sf0 = (float)dst->ne[0]/src0->ne[0];
float sf1 = (float)dst->ne[1]/src0->ne[1];
float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3];
if (mode == GGML_SCALE_MODE_NEAREST) {
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
pixel_offset = 0.0f;
}
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
}
}

View file

@ -2722,7 +2722,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@ -6276,13 +6276,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
if (workgroups_x == 1 && shader_core_count > 0) {
// Try to run two workgroups per SM.
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
@ -6416,7 +6416,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{

View file

@ -2,9 +2,9 @@
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 32
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
uint k_num;
} p;
shared float tmpsh[BLOCK_SIZE];
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
@ -32,23 +34,51 @@ void main() {
// Compute the max m value for the row
float m_max = -1.0/0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float m = data_a[m_offset + k * lm_stride];
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}
// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];
barrier();
// Compute L based on m_max
float L = 0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float l = data_a[l_offset + k * lm_stride];
float m = data_a[m_offset + k * lm_stride];
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}
// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
for (uint d = tid; d < D; d += BLOCK_SIZE) {
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;

View file

@ -14,21 +14,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
return;
}
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;

View file

@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
if (i0 >= p.n_dims) {
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
return;
}
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View file

@ -13,21 +13,19 @@ void main() {
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
if (i0 >= p.n_dims) {
data_d[idst + 0] = data_a[ix + 0];
data_d[idst + 1] = data_a[ix + 1];
return;
}
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View file

@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
FALCON_H1 = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
@ -357,6 +358,8 @@ class MODEL_ARCH(IntEnum):
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
HUNYUAN_MOE = auto()
SMOLLM3 = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -660,6 +663,9 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -2211,6 +2217,77 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,
# Input layernorm
MODEL_TENSOR.ATTN_NORM,
# Attention components
MODEL_TENSOR.ATTN_Q, # Query projection
MODEL_TENSOR.ATTN_K, # Key projection
MODEL_TENSOR.ATTN_V, # Value projection
MODEL_TENSOR.ATTN_OUT, # Output projection
# SSM components (Mamba2 specific)
MODEL_TENSOR.SSM_IN, # Input projection for SSM
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
MODEL_TENSOR.SSM_DT, # Delta time projection
MODEL_TENSOR.SSM_A, # A parameter (log form)
MODEL_TENSOR.SSM_D, # D parameter
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
MODEL_TENSOR.SSM_OUT, # Output projection
# Pre-feedforward layernorm
MODEL_TENSOR.FFN_PRE_NORM,
# Feed-forward network components
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
MODEL_TENSOR.FFN_DOWN, # Down projection
MODEL_TENSOR.FFN_UP, # Up projection
# Post-feedforward layernorm
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.SMOLLM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

View file

@ -286,12 +286,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"model.layers.{bid}.pre_ff_layernorm.weight",
),
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),
MODEL_TENSOR.FFN_GATE_INP: (
@ -303,6 +305,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -362,6 +365,8 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj",
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),
# AWQ-activation gate
@ -398,6 +403,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
),
# Feed-forward down
@ -447,11 +453,13 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
),
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
@ -461,6 +469,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
@ -547,11 +556,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj",
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d",
),
MODEL_TENSOR.SSM_X: (
@ -562,25 +573,30 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj",
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log",
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D",
),
MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.mamba.out_proj", # falcon-h1
),
MODEL_TENSOR.TIME_MIX_W0: (

View file

@ -120,6 +120,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
};
enum llama_rope_type {

File diff suppressed because one or more lines are too long

View file

@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@ -78,6 +79,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -1022,6 +1025,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
@ -1694,12 +1721,52 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
LLM_ARCH_SMOLLM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@ -1925,9 +1992,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
// List all mamba-attention hybrid models here
switch (arch) {
case LLM_ARCH_FALCON_H1:
return true;
default:
return false;
}

View file

@ -50,6 +50,7 @@ enum llm_arch {
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
@ -82,6 +83,8 @@ enum llm_arch {
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_UNKNOWN,
};

View file

@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else {
// template not supported
return -1;

View file

@ -44,6 +44,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

View file

@ -377,14 +377,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
ubatch = balloc.split_equal(n_ubatch, false);
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
if (!prepare(ubatches)) {
break;
}

File diff suppressed because it is too large Load diff

View file

@ -94,6 +94,7 @@ enum llm_type {
LLM_TYPE_57B_A14B,
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_30B_A3B,
LLM_TYPE_235B_A22B,
LLM_TYPE_E2B,

View file

@ -576,6 +576,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@ -1758,6 +1759,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe"||
tokenizer_pre == "falcon3" ||
tokenizer_pre == "falcon-h1" ||
tokenizer_pre == "pixtral") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
@ -1790,7 +1792,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "jina-de" ||
tokenizer_pre == "gigachat" ||
tokenizer_pre == "jina-v2-es" ||
tokenizer_pre == "jina-v2-de") {
tokenizer_pre == "jina-v2-de" ||
tokenizer_pre == "a.x-4.0") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "jina-v1-en" ||
@ -1892,6 +1895,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "seed-coder") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
clean_spaces = false;
} else if (
tokenizer_pre == "hunyuan") {
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}

View file

@ -4806,14 +4806,14 @@ int main(int argc, char ** argv) {
// register static assets routes
if (!params.public_path.empty()) {
// Set the base directory for serving static files
bool is_found = svr->set_mount_point("/", params.public_path);
bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
if (!is_found) {
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
return 1;
}
} else {
// using embedded static index.html
svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
res.set_content("Error: gzip is not supported by this browser", "text/plain");
} else {
@ -4829,37 +4829,37 @@ int main(int argc, char ** argv) {
}
// register API routes
svr->Get ("/health", handle_health); // public endpoint (no API key check)
svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props);
svr->Post("/props", handle_props_change);
svr->Post("/api/show", handle_api_show);
svr->Get ("/models", handle_models); // public endpoint (no API key check)
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
svr->Post("/completion", handle_completions); // legacy
svr->Post("/completions", handle_completions);
svr->Post("/v1/completions", handle_completions_oai);
svr->Post("/chat/completions", handle_chat_completions);
svr->Post("/v1/chat/completions", handle_chat_completions);
svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint
svr->Post("/infill", handle_infill);
svr->Post("/embedding", handle_embeddings); // legacy
svr->Post("/embeddings", handle_embeddings);
svr->Post("/v1/embeddings", handle_embeddings_oai);
svr->Post("/rerank", handle_rerank);
svr->Post("/reranking", handle_rerank);
svr->Post("/v1/rerank", handle_rerank);
svr->Post("/v1/reranking", handle_rerank);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
svr->Post("/apply-template", handle_apply_template);
svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
svr->Get (params.api_prefix + "/metrics", handle_metrics);
svr->Get (params.api_prefix + "/props", handle_props);
svr->Post(params.api_prefix + "/props", handle_props_change);
svr->Post(params.api_prefix + "/api/show", handle_api_show);
svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
svr->Post(params.api_prefix + "/completions", handle_completions);
svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
svr->Post(params.api_prefix + "/infill", handle_infill);
svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
svr->Post(params.api_prefix + "/rerank", handle_rerank);
svr->Post(params.api_prefix + "/reranking", handle_rerank);
svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
// LoRA adapters hotswap
svr->Get ("/lora-adapters", handle_lora_adapters_list);
svr->Post("/lora-adapters", handle_lora_adapters_apply);
svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
// Save & load slots
svr->Get ("/slots", handle_slots);
svr->Post("/slots/:id_slot", handle_slots_action);
svr->Get (params.api_prefix + "/slots", handle_slots);
svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
//
// Start the server