mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/build.yml # docs/development/HOWTO-add-model.md # ggml/src/ggml-sycl/rope.cpp # tests/test-backend-ops.cpp
This commit is contained in:
commit
b8c1fc7c9e
30 changed files with 1784 additions and 263 deletions
|
@ -2736,6 +2736,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.public_path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
|
||||
add_opt(common_arg(
|
||||
{"--api-prefix"}, "PREFIX",
|
||||
string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.api_prefix = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
|
||||
add_opt(common_arg(
|
||||
{"--no-webui"},
|
||||
string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),
|
||||
|
|
|
@ -366,6 +366,7 @@ struct common_params {
|
|||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string api_prefix = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
|
|
|
@ -815,6 +815,24 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
|
||||
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
|
||||
res = "minerva-7b"
|
||||
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
|
||||
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
|
||||
res = "hunyuan"
|
||||
if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
|
||||
# ref: https://huggingface.co/skt/A.X-4.0
|
||||
res = "a.x-4.0"
|
||||
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
|
||||
res = "falcon-h1"
|
||||
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
|
||||
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
|
||||
res = "falcon-h1"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
@ -4896,15 +4914,17 @@ class Mamba2Model(TextModel):
|
|||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
|
||||
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
|
||||
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
|
||||
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
|
||||
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
|
||||
head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
|
||||
n_group = self.find_hparam(["n_groups"], optional=True) or 1
|
||||
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
|
||||
# Fail early for models which don't have a block expansion factor of 2
|
||||
# TODO: does this really matter?
|
||||
# skip the assertion for FalconH1 Model
|
||||
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
|
||||
assert d_inner == 2 * d_model
|
||||
assert d_inner % head_dim == 0
|
||||
|
||||
|
@ -4943,7 +4963,7 @@ class Mamba2Model(TextModel):
|
|||
data_torch = data_torch.reshape((*data_torch.shape, 1))
|
||||
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
|
||||
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
|
||||
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
|
||||
n_group = self.hparams.get("n_groups", 1)
|
||||
data_torch = data_torch.reshape((n_group, d_inner // n_group))
|
||||
|
||||
|
@ -6535,6 +6555,277 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
|||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
|
||||
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
    # Converter for the "FalconH1ForCausalLM" architecture. It reuses the Mamba2
    # conversion path and additionally (a) resolves hyperparameters that are stored
    # with a "mamba_" prefix, (b) folds the Falcon-H1 multipliers into the converted
    # weights and (c) writes the attention-related GGUF metadata.
    model_arch = gguf.MODEL_ARCH.FALCON_H1

    def __init__(self, *args, **kwargs):
        # Set the hparam prefixes for Falcon Mamba2.
        # This is assigned before the base initializer runs because the overridden
        # find_hparam() below reads it on every lookup.
        self.hparam_prefixes = ["mamba"]

        # Initialize the base Mamba2Model
        super().__init__(*args, **kwargs)

        # Use Llama conversion for attention
        self._transformer_model_class = LlamaModel

        # n_group and d_inner are used when reshaping the mamba2 norm tensor in
        # modify_tensors(); d_inner and d_head are validated in set_gguf_parameters()
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["mamba_d_ssm"])
        self.d_head = self.find_hparam(["d_head"])

        # Initialize any Falcon Mamba2 specific attributes
        self.has_attention = True  # Falcon Mamba2 has attention components

        # Load Falcon-H1 multipliers from hyperparameters.
        # NOTE(review): these are looked up with optional=True, so they may be None;
        # modify_tensors() multiplies by them unconditionally - confirm that every
        # supported checkpoint defines them.
        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
        self.intermediate_size = self.find_hparam(["intermediate_size"])
        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)

    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        # Look up a hyperparameter under the given names and, as a fallback, under
        # the same names prefixed with each entry of self.hparam_prefixes
        # (e.g. "d_head" -> "mamba_d_head"). Unprefixed names keep priority because
        # they come first in the list handed to the base implementation.
        #
        # `keys` is materialized once: it is annotated as a generic Iterable and is
        # traversed once per prefix plus once for the unprefixed names, which would
        # silently drop entries if a one-shot iterator were passed.
        base_keys = list(keys)
        prefixed = [
            "_".join([pfx, k])
            for pfx in self.hparam_prefixes
            for k in base_keys
        ]
        return super().find_hparam(base_keys + prefixed, *args, **kwargs)

    def set_vocab(self):
        self._set_vocab_gpt2()

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Run the Mamba2 conversion first, then scale the resulting tensor according
        # to which projection the *source* tensor name refers to. Only the first
        # tensor returned by the base class is kept and returned.
        tensors = list(super().modify_tensors(data_torch, name, bid))
        tensor = tensors[0][1]

        if "down_proj" in name:
            tensor = tensor * self.mlp_multipliers[1]
        elif "gate_proj" in name:
            tensor = tensor * self.mlp_multipliers[0]
        elif "k_proj" in name:
            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
        elif "q_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "v_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "o_proj" in name:
            tensor = tensor * self.attention_out_multiplier
        elif "out_proj" in name:
            tensor = tensor * self.ssm_out_multiplier
        elif "in_proj" in name:
            # `tensor * ...` creates a new tensor, so the in-place row scaling below
            # does not touch the tensor returned by the base class.
            tensor = tensor * self.ssm_in_multiplier
            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
            intermediate_size = self.hparams["mamba_d_ssm"]
            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
            # The rows are scaled in five consecutive slices:
            #   [0, d_ssm), [d_ssm, 2*d_ssm), two slices of n_groups*d_state rows each,
            #   and everything that remains.
            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
        elif "lm_head" in name:
            tensor = tensor * self.hparams["lm_head_multiplier"]
        elif "embed_tokens" in name:
            tensor = tensor * self.hparams["embedding_multiplier"]
        elif "mamba.norm" in name:
            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)

        tensors = [(tensors[0][0], tensor)]
        return tensors

    def set_gguf_parameters(self):
        # Write the Mamba2 metadata first, then add/override the Falcon-H1 specifics.
        super().set_gguf_parameters()

        ## General Params ##
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        # Override some Mamba2 defaults
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        ## Attention params ##
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])  # Override value 0 from Mamba2
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_key_length(self.hparams["head_dim"])
        self.gguf_writer.add_value_length(self.hparams["head_dim"])

        ## Validation ##
        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"

        # Add any other Falcon Mamba2 specific configuration
        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
||||
|
||||
|
||||
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
    # Converter for the "HunYuanMoEV1ForCausalLM" architecture: rebuilds a GPT-2 style
    # BPE vocabulary from the tokenizer's mergeable ranks, writes the MoE / RoPE
    # metadata and merges the per-expert weights of each layer into stacked tensors.
    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # For handling tied embeddings
        # NOTE(review): this attribute is only written (never read) inside this class.
        self._tok_embd = None

    def set_vocab(self):
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        # 1. Get the pre-tokenizer identifier hash
        tokpre = self.get_vocab_base_pre(tokenizer)

        # 2. Reverse-engineer the merges list from mergeable_ranks
        merges = []
        vocab = {}
        mergeable_ranks = tokenizer.mergeable_ranks
        for token, rank in mergeable_ranks.items():
            vocab[QwenModel.token_bytes_to_string(token)] = rank
            if len(token) == 1:
                # single-byte tokens cannot be the result of a merge
                continue
            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
            if len(merged) == 2:  # todo this is an assert in Qwen, why?
                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))

        # 3. Generate the tokens and toktypes lists
        vocab_size = self.hparams["vocab_size"]
        assert tokenizer.vocab_size == vocab_size
        special_tokens = tokenizer.special_tokens
        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
        # Build the id set once instead of scanning special_tokens.values() for every id.
        special_ids = set(special_tokens.values())
        tokens: list[str] = []
        toktypes: list[int] = []
        for i in range(vocab_size):
            if i not in reverse_vocab:
                # ids without a token get a placeholder so the list stays dense
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token = reverse_vocab[i]
                tokens.append(token)
                if i in special_ids:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)

        # 4. Write all vocab-related fields to the GGUF writer
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)
        self.gguf_writer.add_token_merges(merges)

        # 5. Add special tokens and chat templates
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
        special_vocab.add_to_gguf(self.gguf_writer)
        # FIX for BOS token: Overwrite incorrect id read from config.json
        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams

        self.gguf_writer.add_expert_count(hparams["num_experts"])
        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])

        # The following hparams are lists; only a single shared value can be written,
        # so every entry is required to be identical.
        moe_intermediate_size = hparams["moe_intermediate_size"]
        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

        moe_topk = hparams["moe_topk"]
        assert all(topk == moe_topk[0] for topk in moe_topk)
        self.gguf_writer.add_expert_used_count(moe_topk[0])

        moe_shared_expert = hparams["num_shared_expert"]
        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])

        # Rope
        # `or {}` also covers a config that contains an explicit `"rope_scaling": null`,
        # for which dict.get() would return None instead of the default.
        rope_scaling = hparams.get("rope_scaling") or {}
        if rope_scaling.get("type") == "dynamic":
            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
            alpha = rope_scaling.get("alpha", 1000)
            base = hparams.get("rope_theta", 10000.0)
            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) = 11158839.9251
            self.gguf_writer.add_rope_freq_base(scaled_base)
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
            self.gguf_writer.add_rope_scaling_factor(1)
            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length

            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"

    # Per-layer staging area for expert weights: one dict (source name -> tensor) per
    # block, created lazily on the first expert tensor.
    _experts: list[dict[str, Tensor]] | None = None

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # Returns the (mapped name, tensor) pairs to write for this source tensor.
        # Expert weights are buffered and produce no output until the whole layer has
        # been collected; a tied lm_head is dropped.
        if name == "model.embed_tokens.weight":
            self._tok_embd = data_torch.clone()

        if name == "lm_head.weight":
            if self.hparams.get("tie_word_embeddings", False):
                logger.info("Skipping tied output layer 'lm_head.weight'")
                return []

        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            # three projections (down/gate/up) per expert
            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                tensors: list[tuple[str, Tensor]] = []
                for w_name in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename])
                        # release the staged entry so prepare_tensors() can detect leftovers
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
        super().prepare_tensors()
        if self._experts is not None:
            # flatten all dicts and verify everything is processed
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
    # Converter for "SmolLM3ForCausalLM"; identical to the Llama conversion except
    # for a patched chat template.
    model_arch = gguf.MODEL_ARCH.SMOLLM3

    def set_vocab(self):
        super().set_vocab()

        # The stock chat template uses the "[:]" array slice, which is unsupported;
        # write a copy of the template with every occurrence stripped.
        # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        if tokenizer.chat_template is None:
            # nothing to patch
            return

        patched_template = tokenizer.chat_template.replace("[:]", "")
        self.gguf_writer.add_chat_template(patched_template)
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
|
|
@ -128,6 +128,7 @@ models = [
|
|||
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
|
||||
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
|
||||
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
|
||||
{"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
@ -137,6 +138,12 @@ pre_computed_hashes = [
|
|||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
# falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
|
||||
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -501,7 +501,7 @@ extern "C" {
|
|||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
GGML_OP_POOL_2D_BACK,
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_UPSCALE,
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_PAD_REFLECT_1D,
|
||||
GGML_OP_ROLL,
|
||||
|
|
|
@ -180,9 +180,9 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#endif
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
|
@ -190,7 +190,10 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
} \
|
||||
} while (0)
|
||||
#else
|
||||
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
GGML_UNUSED(nbytes); \
|
||||
} while (0)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
||||
|
||||
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
||||
|
|
|
@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
|
|
@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
||||
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
|
|
@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|||
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_BF16:
|
||||
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
|
@ -210,6 +214,10 @@ void get_rows_cuda(
|
|||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
|
|
|
@ -3205,6 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
@ -3378,7 +3380,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_UPSCALE:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
|
|
|
@ -50,21 +50,19 @@ static __global__ void rope_norm(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + 0] = x[ix + 0];
|
||||
dst[idst + 1] = x[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
@ -94,21 +92,19 @@ static __global__ void rope_neox(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
||||
|
@ -138,21 +134,19 @@ static __global__ void rope_multi(
|
|||
|
||||
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
const int i = row_dst*ne0 + i0;
|
||||
|
||||
dst[i + 0] = x[i + 0];
|
||||
dst[i + 1] = x[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
|
||||
const int idst = row_dst*ne0 + i0/2;
|
||||
const int ix = channel_x*s2 + row_x*s1 + i0/2;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
|
||||
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
|
|
@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst,
|
|||
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
|
||||
}
|
||||
|
||||
// Bilinear upscale kernel: each thread produces one element of the 4-D destination.
//
//   x, dst              source / destination buffers (dst is written densely at [index])
//   nb00..nb03          source strides in BYTES (they are added to a `const char *`)
//   ne00_src, ne01_src  source extent along dims 0 and 1, used to clamp the sample coords
//   ne1x_dst            destination extents along the four dims
//   sf0..sf3            per-dim scale factors; destination coordinates are divided by them
//                       to obtain source coordinates
//   pixel_offset        added to the destination coordinate before scaling and subtracted
//                       again afterwards (presumably 0.5f for half-pixel centers, 0 for
//                       align-corners - confirm against the caller)
static __global__ void upscale_f32_bilinear(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    // Widen before multiplying: blockIdx.x * blockDim.x and the ne* products below are
    // otherwise evaluated in 32-bit arithmetic and can overflow for large tensors.
    const int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    const int64_t ne_dst_01          = (int64_t)ne10_dst * ne11_dst;
    const int64_t ne_dst_012         = ne_dst_01 * ne12_dst;
    const int64_t dst_total_elements = ne_dst_012 * ne13_dst;

    if (index >= dst_total_elements) {
        return;
    }

    // decompose the flat index into destination coordinates (dim 0 is fastest)
    const int i10_dst = index % ne10_dst;
    const int i11_dst = (index / ne10_dst) % ne11_dst;
    const int i12_dst = (index / ne_dst_01) % ne12_dst;
    const int i13_dst = index / ne_dst_012;

    // dims 2 and 3 are not interpolated: pick the source slice by truncating division
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // dim 1: fractional source row, its two neighbouring rows (clamped to the source
    // extent) and the interpolation weight
    const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
    int y0_src = (int)floorf(y_src_f);
    int y1_src = y0_src + 1;

    y0_src = max(0, min(y0_src, ne01_src - 1));
    y1_src = max(0, min(y1_src, ne01_src - 1));

    // the weight is computed against the clamped row, then clamped to [0, 1]
    float dy = y_src_f - (float)y0_src;
    dy = max(0.0f, min(dy, 1.0f));

    // dim 0: same procedure for the column
    float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
    int x0_src = (int)floorf(x_src_f);
    int x1_src = x0_src + 1;

    x0_src = max(0, min(x0_src, ne00_src - 1));
    x1_src = max(0, min(x1_src, ne00_src - 1));

    float dx = x_src_f - (float)x0_src;
    dx = max(0.0f, min(dx, 1.0f));

    // the four neighbouring samples: a=(x0,y0) b=(x1,y0) c=(x0,y1) d=(x1,y1)
    const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
    const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);

    const float val_a = *p_a;
    const float val_b = *p_b;
    const float val_c = *p_c;
    const float val_d = *p_d;

    // weighted sum of the four samples
    float result = val_a * (1.0f - dx) * (1.0f - dy) +
                   val_b * dx * (1.0f - dy) +
                   val_c * (1.0f - dx) * dy +
                   val_d * dx * dy;

    dst[index] = result;
}
|
||||
|
||||
// Host-side launcher for the `upscale_f32` kernel: one thread per destination element,
// CUDA_UPSCALE_BLOCK_SIZE threads per block, enqueued on `stream`.
//   nb00..nb03  strides forwarded unchanged to the kernel
//   ne10..ne13  destination extents
//   sf0..sf3    per-dim scale factors forwarded unchanged to the kernel
static void upscale_f32_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3,
        cudaStream_t stream) {
    // Widen the first factor so the product is evaluated in 64 bits; multiplying four
    // ints and widening only the result can overflow first.
    const int64_t dst_size   = (int64_t)ne10 * ne11 * ne12 * ne13;
    // ceil-divide so a partially filled last block is still launched
    const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
}
|
||||
|
||||
static void upscale_f32_bilinear_cuda(const float * x, float * dst,
|
||||
const int nb00, const int nb01, const int nb02, const int nb03,
|
||||
const int ne00_src, const int ne01_src,
|
||||
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
|
||||
const float sf0, const float sf1, const float sf2, const float sf3,
|
||||
const float pixel_offset, cudaStream_t stream) {
|
||||
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
|
||||
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||
|
||||
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const int mode_flags = dst->op_params[0];
|
||||
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
|
||||
|
||||
float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
|
||||
if (mode == GGML_SCALE_MODE_NEAREST) {
|
||||
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
|
||||
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
|
||||
float pixel_offset = 0.5f;
|
||||
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
|
||||
sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
|
||||
sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
|
||||
pixel_offset = 0.0f;
|
||||
}
|
||||
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
sf0, sf1, sf2, sf3, pixel_offset, stream);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2722,7 +2722,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
||||
|
@ -6276,13 +6276,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
||||
if (workgroups_x == 1 && shader_core_count > 0) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
workgroups_x = split_k;
|
||||
}
|
||||
|
@ -6416,7 +6416,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
},
|
||||
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
||||
} else {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||
|
@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
|
|||
uint k_num;
|
||||
} p;
|
||||
|
||||
shared float tmpsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
|
@ -32,23 +34,51 @@ void main() {
|
|||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
m_max = max(m_max, m);
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = m_max;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
m_max = max(m_max, tmpsh[tid + s]);
|
||||
tmpsh[tid] = m_max;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
m_max = tmpsh[0];
|
||||
|
||||
barrier();
|
||||
|
||||
// Compute L based on m_max
|
||||
float L = 0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
float l = data_a[l_offset + k * lm_stride];
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
||||
float l = data_a[l_offset + (k + tid) * lm_stride];
|
||||
float m = data_a[m_offset + (k + tid) * lm_stride];
|
||||
L += exp(m - m_max) * l;
|
||||
}
|
||||
|
||||
// reduce across the workgroup
|
||||
tmpsh[tid] = L;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
L += tmpsh[tid + s];
|
||||
tmpsh[tid] = L;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
L = tmpsh[0];
|
||||
|
||||
L = 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
// Scale and sum the O contributions based on m_max and store the result to memory
|
||||
for (uint d = tid; d < D; d += BLOCK_SIZE) {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
|
|
|
@ -14,21 +14,19 @@ void main() {
|
|||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
||||
const int sec_w = p.sections[1] + p.sections[0];
|
||||
const uint sector = (i0 / 2) % sect_dims;
|
||||
|
|
|
@ -13,21 +13,19 @@ void main() {
|
|||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0/2;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
||||
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
|
|
@ -13,21 +13,19 @@ void main() {
|
|||
|
||||
const uint row_dst = gl_GlobalInvocationID.x;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
const uint i = row_dst*ne0 + i0;
|
||||
|
||||
data_d[i + 0] = data_a[i + 0];
|
||||
data_d[i + 1] = data_a[i + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const uint row_x = row_dst % ne1;
|
||||
const uint channel_x = row_dst / ne1;
|
||||
|
||||
const uint idst = row_dst*ne0 + i0;
|
||||
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
||||
|
||||
if (i0 >= p.n_dims) {
|
||||
data_d[idst + 0] = data_a[ix + 0];
|
||||
data_d[idst + 1] = data_a[ix + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
||||
|
||||
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
||||
|
|
|
@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
|
|||
LLAMA4 = auto()
|
||||
DECI = auto()
|
||||
FALCON = auto()
|
||||
FALCON_H1 = auto()
|
||||
BAICHUAN = auto()
|
||||
GROK = auto()
|
||||
GPT2 = auto()
|
||||
|
@ -357,6 +358,8 @@ class MODEL_ARCH(IntEnum):
|
|||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
HUNYUAN_MOE = auto()
|
||||
SMOLLM3 = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
@ -660,6 +663,9 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
MODEL_ARCH.FALCON_H1: "falcon-h1",
|
||||
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
|
||||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
@ -2211,6 +2217,77 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.FALCON_H1: [
|
||||
# Token embedding
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
||||
# Input layernorm
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
|
||||
# Attention components
|
||||
MODEL_TENSOR.ATTN_Q, # Query projection
|
||||
MODEL_TENSOR.ATTN_K, # Key projection
|
||||
MODEL_TENSOR.ATTN_V, # Value projection
|
||||
MODEL_TENSOR.ATTN_OUT, # Output projection
|
||||
|
||||
# SSM components (Mamba2 specific)
|
||||
MODEL_TENSOR.SSM_IN, # Input projection for SSM
|
||||
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
|
||||
MODEL_TENSOR.SSM_DT, # Delta time projection
|
||||
MODEL_TENSOR.SSM_A, # A parameter (log form)
|
||||
MODEL_TENSOR.SSM_D, # D parameter
|
||||
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
|
||||
MODEL_TENSOR.SSM_OUT, # Output projection
|
||||
|
||||
# Pre-feedforward layernorm
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
|
||||
# Feed-forward network components
|
||||
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
|
||||
MODEL_TENSOR.FFN_DOWN, # Down projection
|
||||
MODEL_TENSOR.FFN_UP, # Up projection
|
||||
|
||||
# Final norm and output projection
|
||||
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
|
||||
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
|
||||
],
|
||||
MODEL_ARCH.HUNYUAN_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.SMOLLM3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
@ -286,12 +286,14 @@ class TensorNameMap:
|
|||
# Pre feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
"model.layers.{bid}.pre_ff_layernorm.weight",
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
|
||||
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
|
||||
"model.layers.{bid}.feed_forward.up_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
|
@ -303,6 +305,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
|
@ -362,6 +365,8 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
|
@ -398,6 +403,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
|
@ -447,11 +453,13 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||
|
@ -461,6 +469,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
||||
|
@ -547,11 +556,13 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj",
|
||||
"backbone.layers.{bid}.mixer.in_proj",
|
||||
"model.layers.{bid}.mamba.in_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
"model.layers.{bid}.conv1d",
|
||||
"backbone.layers.{bid}.mixer.conv1d",
|
||||
"model.layers.{bid}.mamba.conv1d",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
|
@ -562,25 +573,30 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.SSM_DT: (
|
||||
"model.layers.{bid}.dt_proj",
|
||||
"backbone.layers.{bid}.mixer.dt_proj",
|
||||
"model.layers.{bid}.mamba.dt_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_A: (
|
||||
"model.layers.{bid}.A_log",
|
||||
"backbone.layers.{bid}.mixer.A_log",
|
||||
"model.layers.{bid}.mamba.A_log",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_D: (
|
||||
"model.layers.{bid}.D",
|
||||
"backbone.layers.{bid}.mixer.D",
|
||||
"model.layers.{bid}.mamba.D",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_NORM: (
|
||||
"model.layers.{bid}.mamba.norm", # falcon-h1
|
||||
"backbone.layers.{bid}.mixer.norm", # mamba2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
"model.layers.{bid}.mamba.out_proj", # falcon-h1
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W0: (
|
||||
|
|
|
@ -120,6 +120,7 @@ extern "C" {
|
|||
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
|
||||
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
|
||||
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
|
||||
LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
|
|
267
klite.embd
267
klite.embd
File diff suppressed because one or more lines are too long
|
@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_MAMBA2, "mamba2" },
|
||||
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||
{ LLM_ARCH_COHERE2, "cohere2" },
|
||||
|
@ -78,6 +79,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
@ -1022,6 +1025,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_FALCON_H1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_XVERSE,
|
||||
{
|
||||
|
@ -1694,12 +1721,52 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMOLLM3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
|
@ -1925,9 +1992,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
|
|||
}
|
||||
|
||||
bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
// TODO: There are currently no hybrid models! Once there are, this will be
|
||||
// the place to identify them
|
||||
// List all mamba-attention hybrid models here
|
||||
switch (arch) {
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -50,6 +50,7 @@ enum llm_arch {
|
|||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_MAMBA2,
|
||||
LLM_ARCH_FALCON_H1,
|
||||
LLM_ARCH_XVERSE,
|
||||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_COHERE2,
|
||||
|
@ -82,6 +83,8 @@ enum llm_arch {
|
|||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
|
@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
|||
return LLM_CHAT_TEMPLATE_LLAMA4;
|
||||
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
|
@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
|
|||
if (add_ass) {
|
||||
ss << "<|response|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
|
||||
// tencent/Hunyuan-A13B-Instruct
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
|
||||
} else if (role == "assistant") {
|
||||
ss << "<|startoftext|>" << message->content << "<|eos|>";
|
||||
} else {
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
|
|
@ -44,6 +44,7 @@ enum llm_chat_template {
|
|||
LLM_CHAT_TEMPLATE_LLAMA4,
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
|
@ -377,14 +377,18 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
if (!prepare(ubatches)) {
|
||||
break;
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -94,6 +94,7 @@ enum llm_type {
|
|||
LLM_TYPE_57B_A14B,
|
||||
LLM_TYPE_17B_16E, // llama4 Scout
|
||||
LLM_TYPE_17B_128E, // llama4 Maverick
|
||||
LLM_TYPE_A13B,
|
||||
LLM_TYPE_30B_A3B,
|
||||
LLM_TYPE_235B_A22B,
|
||||
LLM_TYPE_E2B,
|
||||
|
|
|
@ -576,6 +576,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json
|
||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
|
@ -1758,6 +1759,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "llama-v3" ||
|
||||
tokenizer_pre == "llama-bpe"||
|
||||
tokenizer_pre == "falcon3" ||
|
||||
tokenizer_pre == "falcon-h1" ||
|
||||
tokenizer_pre == "pixtral") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
|
||||
ignore_merges = true;
|
||||
|
@ -1790,7 +1792,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "jina-de" ||
|
||||
tokenizer_pre == "gigachat" ||
|
||||
tokenizer_pre == "jina-v2-es" ||
|
||||
tokenizer_pre == "jina-v2-de") {
|
||||
tokenizer_pre == "jina-v2-de" ||
|
||||
tokenizer_pre == "a.x-4.0") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
|
||||
} else if (
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
|
@ -1892,6 +1895,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
tokenizer_pre == "seed-coder") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "hunyuan") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
}
|
||||
|
|
|
@ -4806,14 +4806,14 @@ int main(int argc, char ** argv) {
|
|||
// register static assets routes
|
||||
if (!params.public_path.empty()) {
|
||||
// Set the base directory for serving static files
|
||||
bool is_found = svr->set_mount_point("/", params.public_path);
|
||||
bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
|
||||
if (!is_found) {
|
||||
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
// using embedded static index.html
|
||||
svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
|
||||
svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
|
||||
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
||||
} else {
|
||||
|
@ -4829,37 +4829,37 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// register API routes
|
||||
svr->Get ("/health", handle_health); // public endpoint (no API key check)
|
||||
svr->Get ("/metrics", handle_metrics);
|
||||
svr->Get ("/props", handle_props);
|
||||
svr->Post("/props", handle_props_change);
|
||||
svr->Post("/api/show", handle_api_show);
|
||||
svr->Get ("/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
|
||||
svr->Post("/completion", handle_completions); // legacy
|
||||
svr->Post("/completions", handle_completions);
|
||||
svr->Post("/v1/completions", handle_completions_oai);
|
||||
svr->Post("/chat/completions", handle_chat_completions);
|
||||
svr->Post("/v1/chat/completions", handle_chat_completions);
|
||||
svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint
|
||||
svr->Post("/infill", handle_infill);
|
||||
svr->Post("/embedding", handle_embeddings); // legacy
|
||||
svr->Post("/embeddings", handle_embeddings);
|
||||
svr->Post("/v1/embeddings", handle_embeddings_oai);
|
||||
svr->Post("/rerank", handle_rerank);
|
||||
svr->Post("/reranking", handle_rerank);
|
||||
svr->Post("/v1/rerank", handle_rerank);
|
||||
svr->Post("/v1/reranking", handle_rerank);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/apply-template", handle_apply_template);
|
||||
svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
|
||||
svr->Get (params.api_prefix + "/metrics", handle_metrics);
|
||||
svr->Get (params.api_prefix + "/props", handle_props);
|
||||
svr->Post(params.api_prefix + "/props", handle_props_change);
|
||||
svr->Post(params.api_prefix + "/api/show", handle_api_show);
|
||||
svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
|
||||
svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
|
||||
svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
|
||||
svr->Post(params.api_prefix + "/completions", handle_completions);
|
||||
svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
|
||||
svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
|
||||
svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
|
||||
svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
|
||||
svr->Post(params.api_prefix + "/infill", handle_infill);
|
||||
svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
|
||||
svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
|
||||
svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
|
||||
svr->Post(params.api_prefix + "/rerank", handle_rerank);
|
||||
svr->Post(params.api_prefix + "/reranking", handle_rerank);
|
||||
svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
|
||||
svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
|
||||
svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
|
||||
svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
|
||||
svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
|
||||
// LoRA adapters hotswap
|
||||
svr->Get ("/lora-adapters", handle_lora_adapters_list);
|
||||
svr->Post("/lora-adapters", handle_lora_adapters_apply);
|
||||
svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
|
||||
svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
|
||||
// Save & load slots
|
||||
svr->Get ("/slots", handle_slots);
|
||||
svr->Post("/slots/:id_slot", handle_slots_action);
|
||||
svr->Get (params.api_prefix + "/slots", handle_slots);
|
||||
svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
|
||||
|
||||
//
|
||||
// Start the server
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue