mirror of https://github.com/LostRuins/koboldcpp.git

commit 12e6928ec2: "i'm gonna regret this, aren't i?"

14 changed files with 237 additions and 81 deletions
@@ -5746,11 +5746,20 @@ class GraniteModel(LlamaModel):
             logger.info("gguf: (granite) logits_scale = %s", logits_scale)


-@ModelBase.register("GraniteMoeForCausalLM")
+@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM")
 class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE

+    def set_gguf_parameters(self):
+        """GraniteMoeShared uses GraniteMoe parameters plus the following:
+        - shared_intermediate_size
+        """
+        super().set_gguf_parameters()
+        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
+            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         """In modeling_granitemoe, the JetMoe implementation of parallel experts
         is used. This essentially merges w1 and w3 into a single tensor with 2x
@@ -5761,12 +5770,21 @@ class GraniteMoeModel(GraniteModel):
         if name.endswith("block_sparse_moe.input_linear.weight"):
             ffn_dim = self.hparams["intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
+            gate, up = data_torch.split(ffn_dim, dim=-2)
             return [
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]

+        if name.endswith("shared_mlp.input_linear.weight"):
+            ffn_dim = self.hparams["shared_intermediate_size"]
+            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
+            gate, up = data_torch.split(ffn_dim, dim=-2)
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+            ]
+
         return super().modify_tensors(data_torch, name, bid)
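Side note on the conversion logic above (a sketch, not part of the commit): the HF checkpoint packs the gate (w1) and up (w3) projections into a single `input_linear.weight` whose second-to-last dimension is `2 * intermediate_size` (or `2 * shared_intermediate_size` for the shared MLP), and `Tensor.split(ffn_dim, dim=-2)` recovers the two halves. The toy shapes below are assumptions for illustration only:

```python
# Illustrative only: mirrors the split performed in modify_tensors() above.
import torch

ffn_dim, n_embd = 8, 16                                    # toy sizes, not real model dims
packed = torch.randn(2 * ffn_dim, n_embd)                  # merged gate+up projection

gate, up = packed.split(ffn_dim, dim=-2)                   # same call as the converter
assert gate.shape == (ffn_dim, n_embd)
assert up.shape   == (ffn_dim, n_embd)
assert torch.equal(torch.cat([gate, up], dim=-2), packed)  # the split is lossless
```

The same call works unchanged for the per-expert `block_sparse_moe.input_linear.weight`, where an extra leading expert dimension is present, because the split is taken along `dim=-2`.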
@@ -1905,6 +1905,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
     ],
     MODEL_ARCH.CHAMELEON: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -428,6 +428,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert.down_proj",                         # qwen2moe
             "model.layers.{bid}.mlp.shared_experts.down_proj",                        # deepseek deepseek2
             "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
+            "model.layers.{bid}.shared_mlp.output_linear",                            # granitemoe
         ),

         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -347,7 +347,7 @@ extern "C" {
         float    yarn_beta_fast;  // YaRN low correction dim
         float    yarn_beta_slow;  // YaRN high correction dim
         uint32_t yarn_orig_ctx;   // YaRN original context size
-        float    defrag_thold;    // defragment the KV cache if holes/size > thold, < 0 disabled (default)
+        float    defrag_thold;    // defragment the KV cache if holes/size > thold, <= 0 disabled (default)

         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
@@ -1481,6 +1481,9 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,   "blk.%d.ffn_up_shexp" },
         },
     },
     {
@@ -1394,6 +1394,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // Add additional layer/vocab/etc checks here for other model sizes
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // For Granite MoE Shared
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
@@ -1830,6 +1833,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
+
+                    // For Granite MoE Shared
+                    if (hparams.n_ff_shexp > 0) {
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+                    }

                 }
             }
         } break;
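As a quick end-to-end sanity check (a sketch, not part of the commit), the optional key read by `ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, ...)` above should appear in a converted GraniteMoeShared GGUF file. The key string and file path below are assumptions based on the writer's `{arch}.expert_shared_feed_forward_length` naming convention:

```python
# Hypothetical check that the converter wrote the shared-expert FFN length.
from gguf import GGUFReader

reader = GGUFReader("granite-moe-shared.gguf")        # hypothetical converted model
key = "granitemoe.expert_shared_feed_forward_length"  # assumed key name
print(key, "present:", key in reader.fields)
```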
@@ -4482,10 +4492,13 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_ff_exp          = %d\n", __func__, hparams.n_ff_exp);
     }

-    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM ||
+        arch == LLM_ARCH_GRANITE ||
+        arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
+        LLAMA_LOG_INFO("%s: n_ff_shexp        = %d\n", __func__, hparams.n_ff_shexp);
     }

     if (arch == LLM_ARCH_BAILINGMOE) {
@@ -4702,11 +4715,6 @@ struct llm_build_llama : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }

-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);

@@ -4778,11 +4786,6 @@ struct llm_build_llama : public llm_graph_context {
                 cb(cur, "ffn_moe_out", il);
             }

-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);

@@ -4805,11 +4808,6 @@ struct llm_build_llama : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);

-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;

@@ -4920,11 +4918,6 @@ struct llm_build_deci : public llm_graph_context {
                 continue;
             }

-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             // modified to support attention-free layer of Llama-3_1-Nemotron-51B
             ggml_tensor * ffn_inp = cur;
             if (n_head > 0) {
@@ -4948,11 +4941,6 @@ struct llm_build_deci : public llm_graph_context {
                 cb(cur, "ffn_out", il);
             }

-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);

@@ -4975,11 +4963,6 @@ struct llm_build_deci : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);

-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
         cb(cur, "result_output", -1);
         res->t_logits = cur;
@@ -12318,6 +12301,195 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     }
 };

+
+struct llm_build_granite : public llm_graph_context {
+    llm_build_granite(
+        const llama_model & model,
+        const llm_graph_params & params,
+        ggml_cgraph * gf,
+        const bool use_rope = true)
+        : llm_graph_context(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - built only if rope enabled
+        ggml_tensor * inp_pos = nullptr;
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and (optionally) RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+
+                    if (!inp_pos) {
+                        inp_pos = build_inp_pos();
+                    }
+                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );

+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * moe_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // For Granite MoE Shared
+                if (hparams.n_ff_shexp > 0) {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                } else {
+                    cur = moe_out;
+                }
+            }
+
+            // For Granite architectures - scale residual
+            cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architectures - scale logits
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // ref: https://github.com/facebookresearch/chameleon
 // based on the original build_llama() function, changes:
 // * qk-norm
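For orientation (a sketch, not part of the commit), the shared-expert branch built above is an ordinary SwiGLU MLP that runs on the same `ffn_norm` output as the routed experts, with its result added to `moe_out`. A rough PyTorch equivalent, with toy shapes and weight orientations that are illustrative assumptions rather than ggml's layout:

```python
# Rough equivalent of: ffn_shexp = build_ffn(..., LLM_FFN_SILU, LLM_FFN_PAR, ...)
#                      cur       = ggml_add(ctx0, moe_out, ffn_shexp)
import torch
import torch.nn.functional as F

n_tokens, n_embd, n_ff_shexp = 3, 16, 32     # toy sizes
x       = torch.randn(n_tokens, n_embd)      # ffn_norm output fed to both branches
moe_out = torch.randn(n_tokens, n_embd)      # stand-in for the build_moe_ffn() result

w_gate = torch.randn(n_ff_shexp, n_embd)     # ffn_gate_shexp
w_up   = torch.randn(n_ff_shexp, n_embd)     # ffn_up_shexp
w_down = torch.randn(n_embd, n_ff_shexp)     # ffn_down_shexp

ffn_shexp = (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T   # SwiGLU shared MLP
cur = moe_out + ffn_shexp                                      # add into the routed-experts output
```

When `n_ff_shexp` is 0 (plain GraniteMoe), the branch is skipped and `cur` is just `moe_out`, matching the `else` arm above.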
@@ -13025,8 +13197,6 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
@@ -13257,6 +13427,11 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arwkv7>(*this, params, gf);
             } break;
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique<llm_build_granite>(*this, params, gf);
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@@ -1,44 +0,0 @@
-# Quantizing CLIP Visual Projector
-
-This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance.
-
-## Usage
-
-To quantize a CLIP visual projector model, use the following command:
-
-```sh
-./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf <type>
-```
-
-After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc).
-
-### Arguments
-
-- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format.
-- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved.
-- `<type>`: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`.
-
-### Quantization Types
-
-The following quantization types are supported, based on the `enum ggml_type` definition:
-
-- `2` - `q4_0`: 4-bit quantization with a single scale value.
-- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block.
-- `6` - `q5_0`: 5-bit quantization with a single scale value.
-- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block.
-- `8` - `q8_0`: 8-bit quantization with a single scale value.
-
-### Example
-
-To quantize a model using the `q4_0` quantization type, you would run:
-
-```sh
-./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2
-```
-
-This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method.
-
-## Notes
-
-- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements.
-- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments.