merge but voxtral not working

Concedo 2025-07-28 22:08:05 +08:00
commit b8425f5a9c
23 changed files with 1040 additions and 140 deletions

View file

@ -1900,6 +1900,7 @@ class StableLMModel(TextModel):
"MixtralForCausalLM", "MixtralForCausalLM",
"VLlama3ForCausalLM", "VLlama3ForCausalLM",
"LlavaForConditionalGeneration", "LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"LlamaModel") "LlamaModel")
class LlamaModel(TextModel): class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA model_arch = gguf.MODEL_ARCH.LLAMA
@ -1912,6 +1913,11 @@ class LlamaModel(TextModel):
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def set_vocab(self): def set_vocab(self):
path_tekken_json = self.dir_model / "tekken.json"
path_tokenizer_json = self.dir_model / "tokenizer.json"
if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
return self.set_vocab_tekken()
try: try:
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()
except FileNotFoundError: except FileNotFoundError:
@ -1944,6 +1950,52 @@ class LlamaModel(TextModel):
if self.hparams.get("vocab_size", 32000) == 49152: if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False) self.gguf_writer.add_add_bos_token(False)
def set_vocab_tekken(self):
vocab = gguf.vocab.MistralVocab(self.dir_model)
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size, (
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
)
if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
self.gguf_writer.add_tokenizer_pre("tekken")
self.gguf_writer.add_token_merges(
vocab.extract_vocab_merges_from_model()
)
logger.info(
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
)
self.gguf_writer.add_bos_token_id(vocab.bos_id)
self.gguf_writer.add_eos_token_id(vocab.eos_id)
self.gguf_writer.add_unk_token_id(vocab.unk_id)
self.gguf_writer.add_pad_token_id(vocab.pad_id)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_vocab_size(vocab.vocab_size)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
script_dir = Path(__file__).parent
template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
with open(template_path, "r", encoding="utf-8") as f:
template = f.read()
self.gguf_writer.add_chat_template(template)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
@ -1971,12 +2023,13 @@ class LlamaModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name \
is_multimodal_tensor = "vision_tower" in name \
or "vision_model" in name \
or "audio_tower" in name \
or "model.connector" in name \
or "multi_modal_projector" in name
if is_vision_tensor:
if is_multimodal_tensor:
return [] # skip vision tensors
elif self.hf_arch == "LlamaModel":
name = "model." + name
@ -7231,6 +7284,7 @@ class WhisperEncoderModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
@ -7272,9 +7326,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("VoxtralForConditionalGeneration")
class VoxtralWhisperEncoderModel(WhisperEncoderModel):
has_vision_encoder = False # no vision encoder
has_audio_encoder = True
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL)
self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size
@ModelBase.register("FalconH1ForCausalLM") @ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model): class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1 model_arch = gguf.MODEL_ARCH.FALCON_H1
@ -7589,6 +7655,88 @@ class LFM2Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
if (self.hparams.get('moe_primary_router_apply_softmax')):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
# YaRN is not enabled by default
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
sliding_window_layout = self.hparams.get("sliding_window_layout")
if sliding_window_layout:
for i in sliding_window_layout:
if i != 0:
sliding_window = self.hparams.get("sliding_window_size")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
break
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["down", "gate", "up"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
###### CONVERSION LOGIC ######
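
For reference, a minimal PyTorch sketch (illustrative only, with made-up shapes) of the expert-merging step in SmallThinkerModel.modify_tensors above: once all per-expert 2D weights for a layer block have been collected, each projection is stacked into a single 3D tensor so it maps onto the merged *_exps GGUF tensors.

import torch

n_expert, n_ff, n_embd, bid = 4, 8, 16, 0
experts = {f"model.layers.{bid}.block_sparse_moe.experts.{x}.up.weight": torch.randn(n_ff, n_embd)
           for x in range(n_expert)}

# stack expert 0..n_expert-1 along a new leading dimension, as torch.stack(datas, dim=0) does above
merged = torch.stack([experts[f"model.layers.{bid}.block_sparse_moe.experts.{x}.up.weight"]
                      for x in range(n_expert)], dim=0)
assert merged.shape == (n_expert, n_ff, n_embd)
# the merged name "model.layers.0.block_sparse_moe.experts.up.weight" is then remapped via map_tensor_name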

View file

@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}

View file

@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}

View file

@ -500,6 +500,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_conv2d_f32;
vk_pipeline pipeline_conv2d_f16_f32;
vk_pipeline pipeline_conv2d_dw_whcn_f32;
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@ -3090,12 +3091,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
} else {
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
ggml_vk_create_pipeline(
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
false);
}
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@ -6982,9 +6992,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_CONV_2D:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (src0->type == GGML_TYPE_F32) {
return ctx->device->pipeline_conv2d_f32;
} else if (src0->type == GGML_TYPE_F16) {
return ctx->device->pipeline_conv2d_f16_f32;
}
}
return nullptr;
case GGML_OP_CONV_2D_DW:
@ -7906,6 +7920,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
// Skip empty set_rows operations. For most ops the empty check at the start
// of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
// with empty srcs.
if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
return;
}
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
@ -8202,13 +8223,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb0 == sizeof(float));
@ -10891,7 +10912,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
const vk_device& device = ggml_vk_get_device(ctx->device);
bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
// Channel-contiguous format is not supported yet.
return (op->src[0]->type == GGML_TYPE_F32 &&
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&

View file

@ -670,6 +670,7 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}}); string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

View file

@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum):
SMOLLM3 = auto()
LFM2 = auto()
DREAM = auto()
SMALLTHINKER = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -695,6 +696,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.DREAM: "dream", MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
} }
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@ -2483,6 +2485,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
],
MODEL_ARCH.SMALLTHINKER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}
@ -2704,6 +2724,7 @@ class VisionProjectorType:
INTERNVL = "internvl" INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio QWEN2A = "qwen2a" # audio
QWEN25O = "qwen2.5o" # omni QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
# Items here are (block size, type size) # Items here are (block size, type size)

View file

@ -317,6 +317,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.router", # llama4 jamba "model.layers.{bid}.feed_forward.router", # llama4 jamba
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan "model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
), ),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -362,6 +363,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.c_fc_1", # exaone "transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w12", # neobert "transformer_encoder.{bid}.ffn.w12", # neobert
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
@ -372,6 +374,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
), ),
MODEL_TENSOR.FFN_UP_SHEXP: ( MODEL_TENSOR.FFN_UP_SHEXP: (
@ -401,6 +404,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w1", # arctic "model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone "transformer.h.{bid}.mlp.c_fc_0", # exaone
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
@ -410,6 +414,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
"model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker
), ),
MODEL_TENSOR.FFN_GATE_SHEXP: ( MODEL_TENSOR.FFN_GATE_SHEXP: (
@ -448,6 +453,7 @@ class TensorNameMap:
"model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
"transformer_encoder.{bid}.ffn.w3", # neobert "transformer_encoder.{bid}.ffn.w3", # neobert
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
@ -459,6 +465,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
), ),
MODEL_TENSOR.FFN_DOWN_SHEXP: ( MODEL_TENSOR.FFN_DOWN_SHEXP: (
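
As a rough usage sketch (assuming a gguf-py build that already contains the smallthinker entries added above), the new HF names resolve through TensorNameMap roughly like this; the expected outputs follow the blk.N.* names from MODEL_TENSOR in constants.py.

from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

tmap = get_tensor_name_map(MODEL_ARCH.SMALLTHINKER, 32)
print(tmap.get_name("model.layers.0.block_sparse_moe.primary_router.weight", try_suffixes=(".weight",)))
# expected: blk.0.ffn_gate_inp.weight
print(tmap.get_name("model.layers.0.block_sparse_moe.experts.up.weight", try_suffixes=(".weight",)))
# expected: blk.0.ffn_up_exps.weight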

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from enum import Enum
import re
import logging
import json
@ -12,6 +13,25 @@ try:
except ImportError:
SentencePieceProcessor = None
try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.utils import (
_filter_valid_tokenizer_files,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
except ImportError:
_mistral_common_installed = False
MistralTokenizer = None
Tekkenizer = None
SentencePieceTokenizer = None
_filter_valid_tokenizer_files = None
else:
_mistral_common_installed = True
import gguf
from .gguf_writer import GGUFWriter
@ -592,3 +612,262 @@ class LlamaHfVocab(Vocab):
def __repr__(self) -> str:
return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class MistralTokenizerType(str, Enum):
spm = "spm"
tekken = "tekken"
# Copied from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
def bytes_to_unicode() -> dict[int, str]:
"""
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs_str = [chr(n) for n in cs]
return dict(zip(bs, cs_str))
class MistralVocab(Vocab):
tokenizer_model = "mistral"
name = "mistral"
added_tokens_dict: dict[str, int] = {}
added_tokens_list: list[str] = []
def __init__(self, base_path: Path):
if not _mistral_common_installed:
raise ImportError(
"To use MistralVocab, please install the `mistral-common` package. "
"You can install it with `pip install mistral-common`."
)
assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
assert MistralTokenizer is not None, "mistral_common is not installed"
assert Tekkenizer is not None, "mistral_common is not installed"
logger.info(f"Loading Mistral tokenizer from {base_path}")
# Find the tokenizer files
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
if len(valid_tokenizer_files) == 0:
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
if len(valid_tokenizer_files) > 1:
if "tekken.json" in valid_tokenizer_files:
tokenizer_file = "tekken.json"
else:
tokenizer_file = sorted(valid_tokenizer_files)[-1]
logger.warning(
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
)
else:
tokenizer_file = valid_tokenizer_files[0]
self.tokenizer = MistralTokenizer.from_file(
base_path / tokenizer_file
).instruct_tokenizer.tokenizer
self.tokenizer_type = (
MistralTokenizerType.tekken
if isinstance(self.tokenizer, Tekkenizer)
else MistralTokenizerType.spm
)
self.vocab_size = self.tokenizer.n_words
self.fname_tokenizer = base_path / tokenizer_file
self._name = (
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
)
@property
def tokenizer_name(self) -> str:
return self._name
@property
def gguf_tokenizer_model(self) -> str:
return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
assert SentencePieceTokenizer is not None, "mistral_common is not installed"
assert isinstance(self.tokenizer, SentencePieceTokenizer), (
f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
)
for i in range(self.tokenizer._model.vocab_size()):
piece = self.tokenizer._model.IdToPiece(i)
text = piece.encode("utf-8")
score: float = self.tokenizer._model.GetScore(i)
toktype = gguf.TokenType.NORMAL
if self.tokenizer._model.IsUnknown(i):
toktype = gguf.TokenType.UNKNOWN
if self.tokenizer._model.IsControl(i):
toktype = gguf.TokenType.CONTROL
if self.tokenizer._model.IsUnused(i):
toktype = gguf.TokenType.UNUSED
if self.tokenizer._model.IsByte(i):
toktype = gguf.TokenType.BYTE
yield text, score, toktype
def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
assert Tekkenizer is not None, "mistral_common is not installed"
assert isinstance(self.tokenizer, Tekkenizer), (
f"Expected Tekkenizer, got {type(self.tokenizer)}"
)
byte_encoder = bytes_to_unicode()
for token_id in range(self.tokenizer.num_special_tokens):
yield (
self.tokenizer.id_to_piece(token_id).encode("utf-8"),
0,
gguf.TokenType.CONTROL
)
for token in self.tokenizer._tekken_token2id_nospecial:
yield (
self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
0,
gguf.TokenType.NORMAL,
)
def get_token_id(self, token: str) -> int:
assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
if self.tokenizer_type == MistralTokenizerType.spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer)
return self.tokenizer._vocab.index(token)
elif self.tokenizer_type == MistralTokenizerType.tekken:
assert isinstance(self.tokenizer, Tekkenizer)
return (
self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
)
else:
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
@property
def bos_id(self) -> int:
return self.tokenizer.bos_id
@property
def eos_id(self) -> int:
return self.tokenizer.eos_id
@property
def pad_id(self) -> int:
if self.tokenizer.pad_id == -1:
return self.eos_id
return self.tokenizer.pad_id
@property
def unk_id(self) -> int:
return self.tokenizer.unk_id
@property
def bos_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
@property
def eos_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
@property
def pad_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
@property
def unk_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
if self.tokenizer_type == MistralTokenizerType.spm:
yield from self._sentencepiece_tokens()
elif self.tokenizer_type == MistralTokenizerType.tekken:
yield from self._tekken_tokens()
else:
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
@staticmethod
def token_bytes_to_string(b, byte_encoder):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
def extract_vocab_merges_from_model(self):
# Adapted from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
f"Expected Tekkenizer, got {type(self.tokenizer)}"
)
mergeable_ranks = self.tokenizer._model._mergeable_ranks
token_bytes_map = {
rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
}
merge_pairs = []
# Sort vocab by rank to ensure correct merge order
for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
merged_token = token_bytes_map[i]
local = []
for j in range(1, len(merged_token)):
left = merged_token[:j]
right = merged_token[j:]
if (
left in mergeable_ranks
and right in mergeable_ranks
and (left + right) in mergeable_ranks
):
local.append((left, right, i))
if not local:
raise ValueError(
f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
)
local = sorted(
local,
key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
reverse=False,
)
merge_pairs.extend(local)
merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)
byte_encoder = bytes_to_unicode()
decoded_merge_pairs = [
[
self.token_bytes_to_string(val[0], byte_encoder),
self.token_bytes_to_string(val[1], byte_encoder),
]
for val in merge_pairs
]
merges = [
" ".join(
[
# ensure the spaces are properly encoded
"".join(chr(ord(c) + 256) if c == " " else c for c in part)
for part in pair
]
)
for pair in decoded_merge_pairs
]
return merges
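
A quick sanity check of the byte-to-unicode table used by the tekken path above (same GPT-2-style mapping as in Transformers), written as a small illustrative snippet: printable ASCII maps to itself, while bytes the BPE code cannot keep verbatim are shifted into the U+0100 range.

m = bytes_to_unicode()          # the helper defined in this diff
assert m[ord("A")] == "A"       # printable ASCII is kept as-is
assert m[ord(" ")] == "\u0120"  # space becomes 'Ġ', as seen in GPT-2/tekken merge lists
assert m[0x0A] == "\u010a"      # newline becomes 'Ċ'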

View file

@ -3321,6 +3321,7 @@ Current version indicated by LITEVER below.
eos_ban_mode: 0, //allow the EOS token when using locally 0=auto,1=unban,2=ban,3=bypass
token_count_multiplier: 100, //100 means 1x
opmode: 4, //what mode are we in? 1=story, 2=adventure, 3=chat, 4=instruct
adventure_roll_modifier: 0,
adventure_switch_mode: 0, //in adventure mode, determine story=0, action=1 or roll=2
adventure_context_mod: true, //extra injection for adventure mode
fix_alpaca_leak: true, //prevents leaking when Alpaca instruct format is used on crappy models
@ -12758,6 +12759,7 @@ Current version indicated by LITEVER below.
document.getElementById("idle_duration").value = localsettings.idle_duration; document.getElementById("idle_duration").value = localsettings.idle_duration;
document.getElementById("fix_alpaca_leak").checked = localsettings.fix_alpaca_leak; document.getElementById("fix_alpaca_leak").checked = localsettings.fix_alpaca_leak;
document.getElementById("adventure_context_mod").checked = localsettings.adventure_context_mod; document.getElementById("adventure_context_mod").checked = localsettings.adventure_context_mod;
document.getElementById("adventure_roll_modifier").value = localsettings.adventure_roll_modifier;
document.getElementById("chat_context_mod").checked = localsettings.chat_context_mod; document.getElementById("chat_context_mod").checked = localsettings.chat_context_mod;
document.getElementById("instruct_has_markdown").checked = localsettings.instruct_has_markdown; document.getElementById("instruct_has_markdown").checked = localsettings.instruct_has_markdown;
document.getElementById("instruct_has_latex").checked = localsettings.instruct_has_latex; document.getElementById("instruct_has_latex").checked = localsettings.instruct_has_latex;
@ -13231,6 +13233,7 @@ Current version indicated by LITEVER below.
localsettings.idle_duration = document.getElementById("idle_duration").value;
localsettings.fix_alpaca_leak = (document.getElementById("fix_alpaca_leak").checked ? true : false);
localsettings.adventure_context_mod = (document.getElementById("adventure_context_mod").checked ? true : false);
localsettings.adventure_roll_modifier = document.getElementById("adventure_roll_modifier").value;
localsettings.chat_context_mod = (document.getElementById("chat_context_mod").checked ? true : false);
localsettings.instruct_has_markdown = (document.getElementById("instruct_has_markdown").checked ? true : false);
localsettings.instruct_has_latex = (document.getElementById("instruct_has_latex").checked ? true : false);
@ -15920,8 +15923,12 @@ Current version indicated by LITEVER below.
if(localsettings.adventure_switch_mode==2)
{
let roll = Math.floor(Math.random() * 20) + 1;
let modif = parseInt(localsettings.adventure_roll_modifier?localsettings.adventure_roll_modifier:0);
let modifstr = (modif>0?`+${modif}`:(modif<0?`${modif}`:""));
roll += modif;
roll = cleannum(roll,1,20);
let outcome = (roll==20?"Perfect":(roll>16?"Excellent":(roll>12?"Good":(roll>8?"Fair":(roll>4?"Poor":"Terrible")))));
diceaddon = ` (Rolled 1d20=${roll}/20, Outcome: ${outcome})`;
diceaddon = ` (Rolled 1d20${modifstr}=${roll}/20, Outcome: ${outcome})`;
}
newgen = "\n\n\> " + newgen + diceaddon + "\n\n";
}
@ -24673,6 +24680,11 @@ Current version indicated by LITEVER below.
class="helptext">Allows using separate Instruction and Response End Tags, instead of combing them with the start tag. Don't change this halfway through a story!</span></span></div> class="helptext">Allows using separate Instruction and Response End Tags, instead of combing them with the start tag. Don't change this halfway through a story!</span></span></div>
<input type="checkbox" title="Separate End Tags" id="separate_end_tags" style="margin:0px 0px 0px auto;" onchange="toggle_separate_end_tags()"> <input type="checkbox" title="Separate End Tags" id="separate_end_tags" style="margin:0px 0px 0px auto;" onchange="toggle_separate_end_tags()">
</div> </div>
<div id="idlesection" class="settinglabel">
<div class="justifyleft settingsmall" title="Adventure Roll Modifier">Adventure Roll Modifer&nbsp;<span class="helpicon">?<span
class="helptext">Adds an integer modifier to adventure mode dice rolls.</span></span></div>
<input class="settinglabel miniinput" title="Adventure Roll Modifer" type="text" inputmode="decimal" value="0" id="adventure_roll_modifier" style="height:16px; width:30px; margin:0px 4px 0px auto;">
</div>
</div>
</div>
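
The adventure-mode change above just biases the d20 roll before the outcome lookup; a small illustrative Python sketch of the same logic (mirroring the JS, not part of the commit):

import random

def adventure_roll(modifier: int = 0) -> tuple[int, str]:
    roll = random.randint(1, 20) + modifier
    roll = max(1, min(20, roll))  # cleannum(roll, 1, 20): the modified roll is clamped back to 1..20
    outcome = ("Perfect" if roll == 20 else "Excellent" if roll > 16 else
               "Good" if roll > 12 else "Fair" if roll > 8 else
               "Poor" if roll > 4 else "Terrible")
    return roll, outcome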

View file

@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_UNKNOWN, "(unknown)" }, { LLM_ARCH_UNKNOWN, "(unknown)" },
}; };
@ -1933,6 +1934,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
},
},
{
LLM_ARCH_DREAM,
{

View file

@ -92,6 +92,7 @@ enum llm_arch {
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_UNKNOWN,
};

View file

@ -298,7 +298,7 @@ llama_context::llama_context(
cross.v_embd.clear();
// reserve pp graph first so that buffers are only allocated once
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
@ -309,7 +309,7 @@ llama_context::llama_context(
n_nodes_pp = ggml_graph_n_nodes(gf);
}
// reserve with tg graph to get the number of splits and nodes
// reserve with tg (token generation) graph to get the number of splits and nodes
{
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
if (!gf) {

View file

@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
return moe_out;
}
ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
ggml_tensor * cur,
ggml_tensor * probs,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llama_expert_gating_func_type gating_op,
int il) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
// add experts selection bias - introduced in DeepSeek V3
// leave probs unbiased as it's later used to get expert weights
ggml_tensor * selection_probs = probs;
if (exp_probs_b != nullptr) {
selection_probs = ggml_add(ctx0, probs, exp_probs_b);
cb(selection_probs, "ffn_moe_probs_biased", il);
}
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);
ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
weights = ggml_soft_max(ctx0, weights);
} else {
weights = ggml_sigmoid(ctx0, weights);
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
}
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * experts = nullptr;
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
cur = ggml_reglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_reglu", il);
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
assert(n_expert_used > 0);
// order the views before the adds
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
ggml_build_forward_expand(gf, cur_experts[i]);
}
// aggregate experts
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
// to avoid potentially a large number of add nodes during warmup
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
ggml_tensor * moe_out = cur_experts[0];
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
}
if (n_expert_used == 1) {
// avoid returning a non-contiguous tensor
moe_out = ggml_cont(ctx0, moe_out);
}
cb(moe_out, "ffn_moe_out", il);
return moe_out;
}
// input embeddings with optional lora
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
const int64_t n_embd = hparams.n_embd;
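
A minimal NumPy sketch (my own illustration, not the ggml graph code) of the routing that build_moe_ffn_from_probs performs for a single token: top-k selection on the (optionally biased) router scores, softmax or normalised-sigmoid mixing weights taken from the unbiased scores, then a weighted sum of the selected ReGLU experts.

import numpy as np

def route_from_probs(scores, n_expert_used, use_softmax):
    top = np.argsort(-scores)[:n_expert_used]      # ggml_top_k over selection_probs
    w = scores[top]                                # weights are gathered from the unbiased scores
    if use_softmax:                                # LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX
        w = np.exp(w - w.max()); w /= w.sum()
    else:                                          # sigmoid gating, renormalised to sum to 1
        w = 1.0 / (1.0 + np.exp(-w)); w /= w.sum()
    return top, w

ids, w = route_from_probs(np.array([0.1, 2.0, -1.0, 0.5]), n_expert_used=2, use_softmax=True)
# the layer output is then sum_i w[i] * expert_ids[i](x), each expert being a gate/up ReGLU followed by down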

View file

@ -625,6 +625,18 @@ struct llm_graph_context {
llama_expert_gating_func_type gating_op,
int il) const;
ggml_tensor * build_moe_ffn_from_probs(
ggml_tensor * cur,
ggml_tensor * probs,
ggml_tensor * up_exps,
ggml_tensor * gate_exps,
ggml_tensor * down_exps,
ggml_tensor * exp_probs_b,
int64_t n_expert,
int64_t n_expert_used,
llama_expert_gating_func_type gating_op,
int il) const;
//
// inputs
//

View file

@ -2,11 +2,17 @@
#include "ggml.h" #include "ggml.h"
void llama_hparams::set_swa_pattern(uint32_t n_pattern) { void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
}
} else {
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
} }
} }
}
bool llama_hparams::is_swa_any() const { bool llama_hparams::is_swa_any() const {
for (uint32_t il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < n_layer; ++il) {

View file

@ -140,7 +140,7 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;
// llama4
// llama4 smallthinker
uint32_t n_moe_layer_step = 0;
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
@ -161,9 +161,10 @@ struct llama_hparams {
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
// dense_first means whether the pattern is start with a dense layer
// note that if n_pattern == 0, all layers are SWA
// if n_pattern == 1, all layers are dense
// example: n_pattern = 3
// example 1: n_pattern = 3, dense_first = false
// il == 0: swa
// il == 1: swa
// il == 2: dense
@ -172,7 +173,13 @@ struct llama_hparams {
// il == 5: dense
// il == 6: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern);
// example 2: n_pattern = 2, dense_first = true
// il == 0: dense
// il == 1: swa
// il == 2: dense
// il == 3: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern, bool dense_first = false);
// return true if one of the layers is SWA
bool is_swa_any() const;
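
A small Python sketch (not part of the commit) of how the two set_swa_pattern variants described above fill swa_layers; True means the layer uses SWA.

def swa_layers(n_layer: int, n_pattern: int, dense_first: bool = False) -> list[bool]:
    if dense_first:
        # layer il is dense whenever il % n_pattern == 0 (SmallThinker uses n_pattern = 4)
        return [n_pattern == 0 or (il % n_pattern != 0) for il in range(n_layer)]
    # original behaviour: each group of n_pattern layers ends with a dense layer
    return [n_pattern == 0 or (il % n_pattern < n_pattern - 1) for il in range(n_layer)]

print(swa_layers(7, 3))                    # [True, True, False, True, True, False, True] -> example 1
print(swa_layers(4, 2, dense_first=True))  # [False, True, False, True]                   -> example 2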

View file

@ -1773,6 +1773,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SMALLTHINKER:
{
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
if (found_swa && hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096;
hparams.set_swa_pattern(4, true);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_no_rope_layer_step = hparams.n_layer;
}
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_4B; break;
case 52: type = LLM_TYPE_20B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture"); default: throw std::runtime_error("unsupported model architecture");
} }
@ -5261,6 +5284,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_SMALLTHINKER:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER");
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER");
// MoE branch
const int64_t n_ff_exp = hparams.n_ff_exp;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -5587,6 +5646,11 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
} }
if (arch == LLM_ARCH_SMALLTHINKER) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
}
vocab.print_info();
}
@ -17111,6 +17175,119 @@ struct llm_build_lfm2 : public llm_graph_context {
}
};
template <bool iswa>
struct llm_build_smallthinker : public llm_graph_context{
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
inp_attn_type * inp_attn = nullptr;
if constexpr (iswa) {
inp_attn = build_attn_inp_kv_unified_iswa();
} else {
inp_attn = build_attn_inp_kv_unified();
}
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
ggml_tensor * probs = nullptr;
probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens]
cb(probs, "ffn_moe_logits", il);
// norm
cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self_attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
probs = ggml_get_rows(ctx0, probs, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// MoE branch
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
nullptr, n_expert, n_expert_used,
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
cb(ffn_out, "ffn_out", il);
cur = ffn_out;
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
@ -17549,6 +17726,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_lfm2>(*this, params);
} break;
case LLM_ARCH_SMALLTHINKER:
{
if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
llm = std::make_unique<llm_build_smallthinker<true>> (*this, params);
} else {
llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
}
} break;
default:
GGML_ABORT("fatal error");
}
@ -17747,6 +17932,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_LFM2:
case LLM_ARCH_SMALLTHINKER:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
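
One detail worth spelling out from llm_build_smallthinker above is the RoPE gate; a tiny illustrative Python sketch (not the C++): when SWA is enabled the default n_no_rope_layer_step = 4 is kept, so every fourth layer skips RoPE, while the no-SWA path sets the step to n_layer, which makes the first clause true and applies RoPE on every layer.

def uses_rope(il: int, n_layer: int, n_no_rope_layer_step: int) -> bool:
    # mirrors: hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0
    return n_no_rope_layer_step == n_layer or il % n_no_rope_layer_step != 0

print([uses_rope(il, 8, 4) for il in range(8)])  # SWA case: layers 0 and 4 skip RoPE
print([uses_rope(il, 8, 8) for il in range(8)])  # no-SWA case: all layers use RoPE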

View file

@ -131,6 +131,7 @@ enum projector_type {
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_UNKNOWN,
};
@ -150,6 +151,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
}; };
static projector_type clip_projector_type_from_string(const std::string & str) {

View file

@ -372,6 +372,16 @@ struct clip_model {
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
}
bool audio_has_stack_frames() const {
return proj_type == PROJECTOR_TYPE_ULTRAVOX
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
}
}; };
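The two predicates added above imply the following capability matrix for the audio projectors handled in this file; a small Python summary for orientation (the keys are the projector strings from `PROJECTOR_TYPE_NAMES`):
```python
# avgpool      -> extra nn.AvgPool1d(2, stride=2) after the encoder
# stack_frames -> StackAudioFrames-style stacking before the projector
AUDIO_PROJECTOR_CAPS = {
    "ultravox": {"avgpool": False, "stack_frames": True},
    "qwen2a":   {"avgpool": True,  "stack_frames": False},
    "voxtral":  {"avgpool": True,  "stack_frames": True},
}
```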
bool enable_gpu_clip = true; bool enable_gpu_clip = true;
@@ -1508,10 +1518,9 @@ struct clip_graph {
cb(cur, "after_transformer", -1); cb(cur, "after_transformer", -1);
if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) { if (model.audio_has_stack_frames()) {
// StackAudioFrames // StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
{
int64_t stride = n_embd * hparams.proj_stack_factor; int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur); int64_t pad = padded_len - ggml_nelements(cur);
@@ -1521,12 +1530,11 @@ struct clip_graph {
} }
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0); ggml_row_size(cur->type, stride), 0);
cb(cur, "after_stacked", -1);
} }
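A hedged NumPy sketch of the frame stacking performed above (note that ggml lists the fastest-varying dimension first, so the `ggml_view_2d` shape reads transposed relative to NumPy): the flattened encoder output is zero-padded to a multiple of `n_embd * proj_stack_factor` and reshaped so that each row holds `proj_stack_factor` consecutive frames.
```python
import numpy as np

def stack_audio_frames(emb, proj_stack_factor):
    """emb: [n_frames, n_embd] encoder output -> [ceil(n_frames / k), n_embd * k]."""
    n_frames, n_embd = emb.shape
    stride = n_embd * proj_stack_factor
    flat = emb.reshape(-1)                            # flatten the frame embeddings
    padded_len = -(-flat.size // stride) * stride     # round up to a multiple of stride (GGML_PAD)
    flat = np.pad(flat, (0, padded_len - flat.size))  # zero-pad the tail
    return flat.reshape(padded_len // stride, stride) # each row = proj_stack_factor stacked frames
```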
cb(cur, "after_stacked", -1); if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
// UltravoxProjector // UltravoxProjector
{
// pre-norm // pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6); cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
@@ -1544,13 +1552,18 @@ struct clip_graph {
// ffn out // ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) { } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
// projector // projector
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur); cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
cur = ggml_add(ctx0, cur, model.mm_fc_b); cur = ggml_add(ctx0, cur, model.mm_fc_b);
} else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
// projector
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
} else { } else {
GGML_ABORT("%s: unknown projector type", __func__); GGML_ABORT("%s: unknown projector type", __func__);
} }
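For the Voxtral branch added above, a sketch of the projector as plain NumPy: a bias-free two-layer MLP with an exact (erf-based) GELU between the matmuls, mapping stacked audio embeddings into the text model's embedding space. Shapes are illustrative; only `mm_1_w`, `mm_2_w` and `ggml_gelu_erf` come from the code.
```python
import math
import numpy as np

_erf = np.vectorize(math.erf)

def gelu_erf(x):
    # exact GELU, matching ggml_gelu_erf (not the tanh approximation)
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))

def voxtral_projector(stacked, mm_1_w, mm_2_w):
    """stacked: [n_tokens, d_in]; mm_1_w: [d_in, d_mid]; mm_2_w: [d_mid, n_embd_text]."""
    return gelu_erf(stacked @ mm_1_w) @ mm_2_w  # two bias-free matmuls with GELU in between
```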
@@ -1695,8 +1708,7 @@ private:
inpL = cur; inpL = cur;
} }
// TODO @ngxson : find a way to move this outside if (ctx->model.audio_has_avgpool()) {
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
ggml_tensor * cur = inpL; ggml_tensor * cur = inpL;
cur = ggml_transpose(ctx0, cur); cur = ggml_transpose(ctx0, cur);
cur = ggml_cont(ctx0, cur); cur = ggml_cont(ctx0, cur);
@@ -2010,6 +2022,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
res = graph.build_llama4(); res = graph.build_llama4();
} break; } break;
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
{ {
res = graph.build_whisper_enc(); res = graph.build_whisper_enc();
@@ -2310,8 +2323,10 @@ struct clip_model_loader {
} break; } break;
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_VOXTRAL:
{ {
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX; bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
model.proj_type == PROJECTOR_TYPE_VOXTRAL;
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
if (hparams.n_mel_bins != 128) { if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
@@ -2595,6 +2610,15 @@ struct clip_model_loader {
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight")); model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias")); model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
} break; } break;
case PROJECTOR_TYPE_VOXTRAL:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
{ {
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -3746,17 +3770,26 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
int scale_factor = ctx->model.hparams.proj_scale_factor; int scale_factor = ctx->model.hparams.proj_scale_factor;
n_patches_sq /= (scale_factor * scale_factor); n_patches_sq /= (scale_factor * scale_factor);
} break; } break;
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
{
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches_sq = n_len / proj_stack_factor / 2;
} break;
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
{ {
// divide by 2 because of whisper n_patches_sq = img->nx;
// another divide by 2 because of nn.AvgPool1d(2, stride=2)
n_patches_sq = img->nx / 4; const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
if (ctx->model.audio_has_stack_frames()) {
GGML_ASSERT(proj_stack_factor > 0);
const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor);
n_patches_sq = n_len / proj_stack_factor;
}
// whisper downscales input token by half after conv1d
n_patches_sq /= 2;
if (ctx->model.audio_has_avgpool()) {
// divide by 2 because of nn.AvgPool1d(2, stride=2)
n_patches_sq /= 2;
}
} break; } break;
default: default:
GGML_ABORT("unsupported projector type"); GGML_ABORT("unsupported projector type");
@@ -4162,6 +4195,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
{ {
// do nothing // do nothing
} break; } break;
@@ -4442,6 +4476,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_IDEFICS3:
return ctx->model.projection->ne[1]; return ctx->model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
return ctx->model.mm_2_w->ne[1]; return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_INTERNVL:
return ctx->model.mm_3_w->ne[1]; return ctx->model.mm_3_w->ne[1];
@@ -4492,7 +4527,8 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A; || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
} }
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {

View file

@@ -289,6 +289,10 @@ struct mtmd_context {
aud_beg = "<|audio_bos|>"; aud_beg = "<|audio_bos|>";
aud_end = "<|audio_eos|>"; aud_end = "<|audio_eos|>";
} else if (proj == PROJECTOR_TYPE_VOXTRAL) {
// [BEGIN_AUDIO] ... (embeddings) ...
aud_beg = "[BEGIN_AUDIO]";
} }
} }

View file

@@ -1,5 +1,5 @@
-r ../../requirements/requirements-convert_legacy_llama.txt -r ../../requirements/requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
pillow~=10.2.0 pillow~=11.3.0
torch~=2.2.1 torch~=2.2.1
torchvision~=0.17.1 torchvision~=0.17.1

View file

@@ -71,6 +71,7 @@ add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big # to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then if [ "$RUN_BIG_TESTS" = true ]; then

View file

@@ -1,18 +1,25 @@
# quantize # quantize
This tool takes a GGUF input model file, typically in a high-precision format like F32 or BF16, and converts it to a quantized format.
Quantization reduces the precision of model weights (e.g., from 32-bit floats to 4-bit integers), which shrinks the model's size and can speed up inference.
This process, however, may introduce some accuracy loss, which is usually measured in [Perplexity](https://huggingface.co/docs/transformers/en/perplexity) (ppl) and/or [Kullback–Leibler Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) (kld).
This can be minimized by using a suitable imatrix file.
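As a rough reference for how these two metrics are computed (a sketch of the definitions, not the `llama-perplexity` implementation): perplexity is the exponential of the average negative log-likelihood of the evaluated text, and the KL divergence compares the quantized model's per-token distribution against the full-precision model's.
```python
import numpy as np

def perplexity(token_logprobs):
    """token_logprobs: natural-log probabilities the model assigned to each observed token."""
    return float(np.exp(-np.mean(token_logprobs)))

def mean_kl_divergence(logits_fp, logits_quant):
    """Mean KL(P_fp || Q_quant) over positions; logits: [n_positions, n_vocab]."""
    def log_softmax(z):
        z = z - z.max(axis=-1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    lp, lq = log_softmax(logits_fp), log_softmax(logits_quant)
    return float(np.mean(np.sum(np.exp(lp) * (lp - lq), axis=-1)))
```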
You can also use the [GGUF-my-repo](https://huggingface.co/spaces/ggml-org/gguf-my-repo) space on Hugging Face to build your own quants without any setup. You can also use the [GGUF-my-repo](https://huggingface.co/spaces/ggml-org/gguf-my-repo) space on Hugging Face to build your own quants without any setup.
Note: It is synced from llama.cpp `main` every 6 hours. Note: It is synced from llama.cpp `main` every 6 hours.
Example usage: Example usage:
```./llama-quantize [options] input-model-f32.gguf [output-model-quant.gguf] type [threads]```
```bash ```bash
# obtain the official LLaMA model weights and place them in ./models # from Hugging Face, obtain the official meta-llama/Llama-3.1-8B model weights and place them in ./models
ls ./models ls ./models
llama-2-7b tokenizer_checklist.chk tokenizer.model config.json model-00001-of-00004.safetensors model-00004-of-00004.safetensors README.md tokenizer.json
# [Optional] for models using BPE tokenizers generation_config.json model-00002-of-00004.safetensors model.safetensors.index.json special_tokens_map.json USE_POLICY.md
ls ./models LICENSE model-00003-of-00004.safetensors original tokenizer_config.json
<folder containing weights and tokenizer json> vocab.json
# [Optional] for PyTorch .bin models like Mistral-7B # [Optional] for PyTorch .bin models like Mistral-7B
ls ./models ls ./models
<folder containing weights and tokenizer json> <folder containing weights and tokenizer json>
@@ -21,7 +28,7 @@ ls ./models
python3 -m pip install -r requirements.txt python3 -m pip install -r requirements.txt
# convert the model to ggml FP16 format # convert the model to ggml FP16 format
python3 convert_hf_to_gguf.py models/mymodel/ python3 convert_hf_to_gguf.py ./models/mymodel/
# quantize the model to 4-bits (using Q4_K_M method) # quantize the model to 4-bits (using Q4_K_M method)
./llama-quantize ./models/mymodel/ggml-model-f16.gguf ./models/mymodel/ggml-model-Q4_K_M.gguf Q4_K_M ./llama-quantize ./models/mymodel/ggml-model-f16.gguf ./models/mymodel/ggml-model-Q4_K_M.gguf Q4_K_M
@@ -37,40 +44,117 @@ Run the quantized model:
./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant" ./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant"
``` ```
When running the larger models, make sure you have enough disk space to store all the intermediate files. Options:
* `--allow-requantize` allows requantizing tensors that have already been quantized. Warning: this can severely reduce quality compared to quantizing from 16-bit or 32-bit
* `--leave-output-tensor` will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing
* `--pure` disables k-quant mixtures and quantizes all tensors to the same type
* `--imatrix` uses the data in the given file, generated by `llama-imatrix`, as an importance matrix for quantization optimizations (highly recommended)
* `--include-weights` use the importance matrix only for the tensor(s) in the list. Cannot be used with `--exclude-weights`
* `--exclude-weights` do not use the importance matrix for the tensor(s) in the list. Cannot be used with `--include-weights`
* `--output-tensor-type` use a specific quant type for the output.weight tensor
* `--token-embedding-type` use a specific quant type for the token embeddings tensor
* `--keep-split` will generate the quantized model in the same shards as the input file; otherwise a single quantized file is produced
Advanced options:
* `--tensor-type` quantize specific tensor(s) to specific quant types. Supports regex syntax. May be specified multiple times.
* `--prune-layers` prune (remove) the layers in the list
* `--override-kv` overrides model metadata by key in the quantized model. May be specified multiple times
Examples:
```bash
# naive Q4_K_M quantization using default settings and 8 CPU threads. Output will be "ggml-model-Q4_K_M.gguf"
./llama-quantize input-model-f32.gguf q4_k_m 8
```
```bash
# quantize model enabling re-quantization, leaving the output tensor unquantized and all others quantized at the same level (Q4_K)
./llama-quantize --allow-requantize --leave-output-tensor --pure input-model-f32.gguf q4_k_m 8
```
```bash
# quantize model using an importance matrix for specified tensors only (attn_v and ffn_down)
./llama-quantize --imatrix imatrix.gguf --include-weights attn_v --include-weights ffn_down input-model-f32.gguf q4_k_m 8
```
```bash
# quantize model setting the output tensor to Q5_K, token embeddings to Q3_K, and keeping the input file's shards
./llama-quantize --imatrix imatrix.gguf --output-tensor-type q5_k --token-embedding-type q3_k --keep-split input-model-f32.gguf q4_k_m 8
```
```bash
# quantize model using regexes to quantize attn_k tensors in odd layers to Q5_K and attn_q tensors in even layers to Q3_K
./llama-quantize --imatrix imatrix.gguf --tensor-type "\.(\d*[13579])\.attn_k=q5_k" --tensor-type "\.(\d*[02468])\.attn_q=q3_k" input-model-f32.gguf q4_k_m 8
```
```bash
# quantize model setting tensors attn_v and ffn_down to Q5_K and pruning layers 20, 21, and 22
./llama-quantize --imatrix imatrix.gguf --tensor-type attn_v=q5_k --tensor-type ffn_down=q5_k --prune-layers 20,21,22 input-model-f32.gguf q4_k_m 8
```
```bash
# override expert used count metadata to 16, prune layers 20, 21, and 22 without quantizing the model (copy tensors) and use specified name for the output file
./llama-quantize --imatrix imatrix.gguf --override-kv qwen3moe.expert_used_count=int:16 --prune-layers 20,21,22 input-model-f32.gguf pruned-model-f32.gguf copy 8
```
## Memory/Disk Requirements ## Memory/Disk Requirements
As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. When running the larger models, make sure you have enough disk space to store all the intermediate files.
As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For example (Llama 3.1):
| Model | Original size | Quantized size (Q4_K_M) |
| ----: | ------------: | ----------------------: |
| 8B | 32.1 GB | 4.9 GB |
| 70B | 280.9 GB | 43.1 GB |
| 405B | 1,625.1 GB | 249.1 GB |
| Model | Original size | Quantized size (Q4_0) |
|------:|--------------:|----------------------:|
| 7B | 13 GB | 3.9 GB |
| 13B | 24 GB | 7.8 GB |
| 30B | 60 GB | 19.5 GB |
| 65B | 120 GB | 38.5 GB |
## Quantization ## Quantization
Several quantization methods are supported. They differ in the resulting model disk size and inference speed. Several quantization methods are supported. They differ in the resulting model disk size and inference speed. For example,
*(outdated)* ### [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)
| Model | Measure      |    F16 |   Q4_0 |   Q4_1 |   Q5_0 |   Q5_1 |   Q8_0 |
|------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|
|    7B | perplexity   | 5.9066 | 6.1565 | 6.0912 | 5.9862 | 5.9481 | 5.9070 |
|    7B | file size    |  13.0G |   3.5G |   3.9G |   4.3G |   4.7G |   6.7G |
|    7B | ms/tok @ 4th |    127 |     55 |     54 |     76 |     83 |     72 |
|    7B | ms/tok @ 8th |    122 |     43 |     45 |     52 |     56 |     67 |
|    7B | bits/weight  |   16.0 |    4.5 |    5.0 |    5.5 |    6.0 |    8.5 |
|   13B | perplexity   | 5.2543 | 5.3860 | 5.3608 | 5.2856 | 5.2706 | 5.2548 |
|   13B | file size    |  25.0G |   6.8G |   7.6G |   8.3G |   9.1G |    13G |
|   13B | ms/tok @ 4th |      - |    103 |    105 |    148 |    160 |    131 |
|   13B | ms/tok @ 8th |      - |     73 |     82 |     98 |    105 |    128 |
|   13B | bits/weight  |   16.0 |    4.5 |    5.0 |    5.5 |    6.0 |    8.5 |
| Measure                     | IQ1_S        | IQ1_M        | IQ2_XXS      | IQ2_XS        | IQ2_S         | IQ2_M        |
| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
| bits/weight                 | 2.0042       | 2.1460       | 2.3824       | 2.5882        | 2.7403        | 2.9294       |
| size (GiB)                  | 1.87         | 2.01         | 2.23         | 2.42          | 2.56          | 2.74         |
| prompt processing t/s @ 512 | 858.88 ±1.22 | 847.99 ±0.47 | 852.39 ±0.85 | 826.99 ±12.51 | 783.55 ±13.73 | 787.68 ±7.00 |
| text generation t/s @ 128   | 79.73 ±0.79  | 72.92 ±0.14  | 79.86 ±0.22  | 78.04 ±0.46   | 77.30 ±2.47   | 74.44 ±0.15  |
| Measure                     | IQ3_XXS      | IQ3_XS       | IQ3_S        | IQ3_M         | IQ4_XS        | IQ4_NL       |
| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
| bits/weight                 | 3.2548       | 3.4977       | 3.6606       | 3.7628        | 4.4597        | 4.6818       |
| size (GiB)                  | 3.04         | 3.27         | 3.42         | 3.52          | 4.17          | 4.38         |
| prompt processing t/s @ 512 | 813.88 ±6.53 | 708.71 ±1.26 | 798.78 ±8.81 | 768.70 ±13.73 | 771.80 ±11.38 | 806.03 ±7.07 |
| text generation t/s @ 128   | 73.95 ±0.20  | 71.67 ±0.54  | 69.31 ±0.63  | 70.15 ±0.33   | 77.51 ±0.20   | 76.63 ±0.28  |
| Measure | Q2_K_S | Q2_K | Q3_K_S | Q3_K_M | Q3_K_L | Q4_K_S |
| --------------------------- | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ |
| bits/weight | 2.9697 | 3.1593 | 3.6429 | 3.9960 | 4.2979 | 4.6672 |
| size (GiB) | 2.78 | 2.95 | 3.41 | 3.74 | 4.02 | 4.36 |
| prompt processing t/s @ 512 | 798.91 ±6.40 | 784.45 ±7.85 | 752.17 ±7.94 | 783.44 ±9.92 | 761.17 ±7.55 | 818.55 ±9.58 |
| text generation t/s @ 128 | 90.01 ±0.12 | 79.85 ±0.20 | 69.84 ±0.18 | 71.68 ±0.22 | 69.38 ±0.49 | 76.71 ±0.20 |
| Measure | Q4_K_S | Q4_K_M | Q5_K_S | Q5_K_M | Q6_K | Q8_0 |
| --------------------------- | ------------ | ------------- | ------------ | ------------ | ------------- | ------------ |
| bits/weight | 4.6672 | 4.8944 | 5.5704 | 5.7036 | 6.5633 | 8.5008 |
| size (GiB) | 4.36 | 4.58 | 5.21 | 5.33 | 6.14 | 7.95 |
| prompt processing t/s @ 512 | 818.55 ±9.58 | 821.81 ±21.44 | 752.52 ±0.99 | 758.69 ±7.43 | 812.01 ±10.82 | 865.09 ±8.30 |
| text generation t/s @ 128 | 76.71 ±0.20 | 71.93 ±1.52 | 69.53 ±0.18 | 67.23 ±1.08 | 58.67 ±3.13 | 50.93 ±0.08 |
| Measure | F16 |
| --------------------------- | ------------ |
| bits/weight | 16.0005 |
| size (GiB) | 14.96 |
| prompt processing t/s @ 512 | 923.49 ±0.53 |
| text generation t/s @ 128 | 29.17 ±0.04 |
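The "bits/weight" and "size" columns are linked by simple arithmetic: file size is roughly parameter count times bits per weight. A quick back-of-the-envelope check against the tables, assuming approximately 8.03 B parameters for Llama 3.1 8B:
```python
def model_size_gib(n_params, bits_per_weight):
    return n_params * bits_per_weight / 8 / 2**30   # bits -> bytes -> GiB

# Q4_K_M at ~4.8944 bits/weight comes out near 4.58 GiB (~4.9 GB),
# consistent with both the table above and the Memory/Disk table earlier.
print(round(model_size_gib(8.03e9, 4.8944), 2))
```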
## Background information on llama-quantize
- [k-quants](https://github.com/ggml-org/llama.cpp/pull/1684) - [k-quants](https://github.com/ggml-org/llama.cpp/pull/1684)
- recent k-quants improvements and new i-quants - k-quants improvements and i-quants
- [#2707](https://github.com/ggml-org/llama.cpp/pull/2707) - [#2707](https://github.com/ggml-org/llama.cpp/pull/2707)
- [#2807](https://github.com/ggml-org/llama.cpp/pull/2807) - [#2807](https://github.com/ggml-org/llama.cpp/pull/2807)
- [#4773 - 2-bit i-quants (inference)](https://github.com/ggml-org/llama.cpp/pull/4773) - [#4773 - 2-bit i-quants (inference)](https://github.com/ggml-org/llama.cpp/pull/4773)
@@ -85,45 +169,3 @@ Several quantization methods are supported. They differ in the resulting model d
- [#5060 - Q3_K_XS](https://github.com/ggml-org/llama.cpp/pull/5060) - [#5060 - Q3_K_XS](https://github.com/ggml-org/llama.cpp/pull/5060)
- [#5196 - 3-bit i-quants](https://github.com/ggml-org/llama.cpp/pull/5196) - [#5196 - 3-bit i-quants](https://github.com/ggml-org/llama.cpp/pull/5196)
- [quantization tuning](https://github.com/ggml-org/llama.cpp/pull/5320), [another one](https://github.com/ggml-org/llama.cpp/pull/5334), and [another one](https://github.com/ggml-org/llama.cpp/pull/5361) - [quantization tuning](https://github.com/ggml-org/llama.cpp/pull/5320), [another one](https://github.com/ggml-org/llama.cpp/pull/5334), and [another one](https://github.com/ggml-org/llama.cpp/pull/5361)
**Llama 2 7B**
| Quantization | Bits per Weight (BPW) |
|--------------|-----------------------|
| Q2_K | 3.35 |
| Q3_K_S | 3.50 |
| Q3_K_M | 3.91 |
| Q3_K_L | 4.27 |
| Q4_K_S | 4.58 |
| Q4_K_M | 4.84 |
| Q5_K_S | 5.52 |
| Q5_K_M | 5.68 |
| Q6_K | 6.56 |
**Llama 2 13B**
Quantization | Bits per Weight (BPW)
-- | --
Q2_K | 3.34
Q3_K_S | 3.48
Q3_K_M | 3.89
Q3_K_L | 4.26
Q4_K_S | 4.56
Q4_K_M | 4.83
Q5_K_S | 5.51
Q5_K_M | 5.67
Q6_K | 6.56
**Llama 2 70B**
Quantization | Bits per Weight (BPW)
-- | --
Q2_K | 3.40
Q3_K_S | 3.47
Q3_K_M | 3.85
Q3_K_L | 4.19
Q4_K_S | 4.53
Q4_K_M | 4.80
Q5_K_S | 5.50
Q5_K_M | 5.65
Q6_K | 6.56