mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-09 16:44:35 +00:00
merge but voxtral not working
This commit is contained in:
commit
b8425f5a9c
23 changed files with 1040 additions and 140 deletions
|
@ -1900,6 +1900,7 @@ class StableLMModel(TextModel):
|
|||
"MixtralForCausalLM",
|
||||
"VLlama3ForCausalLM",
|
||||
"LlavaForConditionalGeneration",
|
||||
"VoxtralForConditionalGeneration",
|
||||
"LlamaModel")
|
||||
class LlamaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
@ -1912,6 +1913,11 @@ class LlamaModel(TextModel):
|
|||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def set_vocab(self):
|
||||
path_tekken_json = self.dir_model / "tekken.json"
|
||||
path_tokenizer_json = self.dir_model / "tokenizer.json"
|
||||
if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
|
||||
return self.set_vocab_tekken()
|
||||
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
|
@ -1944,6 +1950,52 @@ class LlamaModel(TextModel):
|
|||
if self.hparams.get("vocab_size", 32000) == 49152:
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
|
||||
def set_vocab_tekken(self):
|
||||
vocab = gguf.vocab.MistralVocab(self.dir_model)
|
||||
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
|
||||
|
||||
tokens = []
|
||||
scores = []
|
||||
toktypes = []
|
||||
|
||||
for text, score, toktype in vocab.all_tokens():
|
||||
tokens.append(text)
|
||||
scores.append(score)
|
||||
toktypes.append(toktype)
|
||||
|
||||
assert len(tokens) == vocab.vocab_size, (
|
||||
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
|
||||
)
|
||||
|
||||
if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
|
||||
self.gguf_writer.add_tokenizer_pre("tekken")
|
||||
self.gguf_writer.add_token_merges(
|
||||
vocab.extract_vocab_merges_from_model()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
|
||||
)
|
||||
|
||||
self.gguf_writer.add_bos_token_id(vocab.bos_id)
|
||||
self.gguf_writer.add_eos_token_id(vocab.eos_id)
|
||||
self.gguf_writer.add_unk_token_id(vocab.unk_id)
|
||||
self.gguf_writer.add_pad_token_id(vocab.pad_id)
|
||||
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
self.gguf_writer.add_vocab_size(vocab.vocab_size)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(False)
|
||||
|
||||
script_dir = Path(__file__).parent
|
||||
template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
self.gguf_writer.add_chat_template(template)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
@ -1971,12 +2023,13 @@ class LlamaModel(TextModel):
|
|||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
is_vision_tensor = "vision_tower" in name \
|
||||
is_multimodal_tensor = "vision_tower" in name \
|
||||
or "vision_model" in name \
|
||||
or "audio_tower" in name \
|
||||
or "model.connector" in name \
|
||||
or "multi_modal_projector" in name
|
||||
|
||||
if is_vision_tensor:
|
||||
if is_multimodal_tensor:
|
||||
return [] # skip vision tensors
|
||||
elif self.hf_arch == "LlamaModel":
|
||||
name = "model." + name
|
||||
|
@ -7231,9 +7284,10 @@ class WhisperEncoderModel(MmprojModel):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hparams["hidden_size"] = self.hparams["d_model"]
|
||||
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
|
||||
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
|
||||
if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
|
||||
self.hparams["hidden_size"] = self.hparams["d_model"]
|
||||
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
|
||||
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
@ -7272,9 +7326,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
|
||||
@ModelBase.register("VoxtralForConditionalGeneration")
|
||||
class VoxtralWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL)
|
||||
self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size
|
||||
|
||||
|
||||
@ModelBase.register("FalconH1ForCausalLM")
|
||||
class FalconH1Model(Mamba2Model):
|
||||
model_arch = gguf.MODEL_ARCH.FALCON_H1
|
||||
|
@ -7589,6 +7655,88 @@ class LFM2Model(TextModel):
|
|||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("SmallThinkerForCausalLM")
|
||||
class SmallThinkerModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
|
||||
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
|
||||
if (self.hparams.get('moe_primary_router_apply_softmax')):
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
# YaRN is not enabled by default
|
||||
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
sliding_window_layout = self.hparams.get("sliding_window_layout")
|
||||
if sliding_window_layout:
|
||||
for i in sliding_window_layout:
|
||||
if i != 0:
|
||||
sliding_window = self.hparams.get("sliding_window_size")
|
||||
if sliding_window:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
break
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down", "gate", "up"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
|
|
@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
K += blockIdx.y*D * nb11;
|
||||
V += blockIdx.y*D * nb21;
|
||||
maskh += blockIdx.y*D;
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
|
@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
}
|
||||
}
|
||||
|
||||
K += gridDim.y*D * nb11;
|
||||
V += gridDim.y*D * nb21;
|
||||
maskh += gridDim.y*D;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
|
|
|
@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
K += blockIdx.y*D * nb11;
|
||||
V += blockIdx.y*D * nb21;
|
||||
maskh += blockIdx.y*D;
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
||||
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
||||
// Increment pointers after each loop:
|
||||
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
||||
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
if (mask) {
|
||||
|
@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
}
|
||||
}
|
||||
|
||||
K += gridDim.y*D * nb11;
|
||||
V += gridDim.y*D * nb21;
|
||||
maskh += gridDim.y*D;
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
|
|
|
@ -500,6 +500,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_conv2d_f32;
|
||||
vk_pipeline pipeline_conv2d_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
||||
|
||||
|
@ -3090,12 +3091,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
|
||||
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
||||
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
|
||||
ggml_vk_create_pipeline(
|
||||
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
|
||||
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
||||
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(
|
||||
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
|
||||
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
||||
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
|
||||
false);
|
||||
ggml_vk_create_pipeline(
|
||||
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
|
||||
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
|
||||
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
|
||||
false);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
@ -6982,9 +6992,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
return ctx->device->pipeline_conv2d_f32;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_conv2d_f32;
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_conv2d_f16_f32;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
|
@ -7906,6 +7920,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
// Skip empty skip_rows operations. For most ops the empty check at the start
|
||||
// of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
|
||||
// with empty srcs.
|
||||
if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
|
@ -8202,13 +8223,13 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
|
||||
static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
|
@ -10891,7 +10912,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE;
|
||||
// Channel-contiguous format is not supported yet.
|
||||
return (op->src[0]->type == GGML_TYPE_F32 &&
|
||||
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
|
|
|
@ -670,6 +670,7 @@ void process_shaders() {
|
|||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
|
||||
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
|
|
|
@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum):
|
|||
SMOLLM3 = auto()
|
||||
LFM2 = auto()
|
||||
DREAM = auto()
|
||||
SMALLTHINKER = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
@ -695,6 +696,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
MODEL_ARCH.LFM2: "lfm2",
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
MODEL_ARCH.SMALLTHINKER: "smallthinker",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
@ -2483,6 +2485,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
],
|
||||
MODEL_ARCH.SMALLTHINKER: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -2704,6 +2724,7 @@ class VisionProjectorType:
|
|||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
VOXTRAL = "voxtral"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
|
|
@ -317,6 +317,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.feed_forward.router", # llama4 jamba
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
|
@ -362,6 +363,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid
|
||||
"transformer_encoder.{bid}.ffn.w12", # neobert
|
||||
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
|
@ -372,6 +374,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
|
@ -401,6 +404,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
|
@ -410,6 +414,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) ernie4.5-moe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
"model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
|
@ -448,6 +453,7 @@ class TensorNameMap:
|
|||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid
|
||||
"transformer_encoder.{bid}.ffn.w3", # neobert
|
||||
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
|
@ -459,6 +465,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
|
@ -12,6 +13,25 @@ try:
|
|||
except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.utils import (
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
_mistral_common_installed = False
|
||||
MistralTokenizer = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_filter_valid_tokenizer_files = None
|
||||
else:
|
||||
_mistral_common_installed = True
|
||||
|
||||
|
||||
import gguf
|
||||
|
||||
from .gguf_writer import GGUFWriter
|
||||
|
@ -592,3 +612,262 @@ class LlamaHfVocab(Vocab):
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
class MistralTokenizerType(str, Enum):
|
||||
spm = "spm"
|
||||
tekken = "tekken"
|
||||
|
||||
|
||||
# Copied from Transformers (Apache 2.0)
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
|
||||
|
||||
def bytes_to_unicode() -> dict[int, str]:
|
||||
"""
|
||||
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
|
||||
characters the bpe code barfs on.
|
||||
|
||||
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
||||
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
||||
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
||||
tables between utf-8 bytes and unicode strings.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs_str = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs_str))
|
||||
|
||||
|
||||
class MistralVocab(Vocab):
|
||||
tokenizer_model = "mistral"
|
||||
name = "mistral"
|
||||
|
||||
added_tokens_dict: dict[str, int] = {}
|
||||
added_tokens_list: list[str] = []
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(
|
||||
"To use MistralVocab, please install the `mistral-common` package. "
|
||||
"You can install it with `pip install mistral-common`."
|
||||
)
|
||||
assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
|
||||
assert MistralTokenizer is not None, "mistral_common is not installed"
|
||||
assert Tekkenizer is not None, "mistral_common is not installed"
|
||||
|
||||
logger.info(f"Loading Mistral tokenizer from {base_path}")
|
||||
|
||||
# Find the tokenizer files
|
||||
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
|
||||
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
|
||||
|
||||
if len(valid_tokenizer_files) == 0:
|
||||
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
|
||||
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
|
||||
if len(valid_tokenizer_files) > 1:
|
||||
if "tekken.json" in valid_tokenizer_files:
|
||||
tokenizer_file = "tekken.json"
|
||||
else:
|
||||
tokenizer_file = sorted(valid_tokenizer_files)[-1]
|
||||
logger.warning(
|
||||
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
|
||||
)
|
||||
else:
|
||||
tokenizer_file = valid_tokenizer_files[0]
|
||||
|
||||
self.tokenizer = MistralTokenizer.from_file(
|
||||
base_path / tokenizer_file
|
||||
).instruct_tokenizer.tokenizer
|
||||
self.tokenizer_type = (
|
||||
MistralTokenizerType.tekken
|
||||
if isinstance(self.tokenizer, Tekkenizer)
|
||||
else MistralTokenizerType.spm
|
||||
)
|
||||
self.vocab_size = self.tokenizer.n_words
|
||||
self.fname_tokenizer = base_path / tokenizer_file
|
||||
self._name = (
|
||||
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
|
||||
)
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def gguf_tokenizer_model(self) -> str:
|
||||
return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
|
||||
|
||||
def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
assert SentencePieceTokenizer is not None, "mistral_common is not installed"
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), (
|
||||
f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
|
||||
)
|
||||
|
||||
for i in range(self.tokenizer._model.vocab_size()):
|
||||
piece = self.tokenizer._model.IdToPiece(i)
|
||||
text = piece.encode("utf-8")
|
||||
score: float = self.tokenizer._model.GetScore(i)
|
||||
|
||||
toktype = gguf.TokenType.NORMAL
|
||||
if self.tokenizer._model.IsUnknown(i):
|
||||
toktype = gguf.TokenType.UNKNOWN
|
||||
if self.tokenizer._model.IsControl(i):
|
||||
toktype = gguf.TokenType.CONTROL
|
||||
|
||||
if self.tokenizer._model.IsUnused(i):
|
||||
toktype = gguf.TokenType.UNUSED
|
||||
if self.tokenizer._model.IsByte(i):
|
||||
toktype = gguf.TokenType.BYTE
|
||||
|
||||
yield text, score, toktype
|
||||
|
||||
def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
assert Tekkenizer is not None, "mistral_common is not installed"
|
||||
assert isinstance(self.tokenizer, Tekkenizer), (
|
||||
f"Expected Tekkenizer, got {type(self.tokenizer)}"
|
||||
)
|
||||
|
||||
byte_encoder = bytes_to_unicode()
|
||||
for token_id in range(self.tokenizer.num_special_tokens):
|
||||
yield (
|
||||
self.tokenizer.id_to_piece(token_id).encode("utf-8"),
|
||||
0,
|
||||
gguf.TokenType.CONTROL
|
||||
)
|
||||
for token in self.tokenizer._tekken_token2id_nospecial:
|
||||
yield (
|
||||
self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
|
||||
0,
|
||||
gguf.TokenType.NORMAL,
|
||||
)
|
||||
|
||||
def get_token_id(self, token: str) -> int:
|
||||
assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
|
||||
if self.tokenizer_type == MistralTokenizerType.spm:
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer)
|
||||
return self.tokenizer._vocab.index(token)
|
||||
elif self.tokenizer_type == MistralTokenizerType.tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer)
|
||||
return (
|
||||
self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
|
||||
|
||||
@property
|
||||
def bos_id(self) -> int:
|
||||
return self.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
if self.tokenizer.pad_id == -1:
|
||||
return self.eos_id
|
||||
return self.tokenizer.pad_id
|
||||
|
||||
@property
|
||||
def unk_id(self) -> int:
|
||||
return self.tokenizer.unk_id
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
|
||||
|
||||
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
|
||||
if self.tokenizer_type == MistralTokenizerType.spm:
|
||||
yield from self._sentencepiece_tokens()
|
||||
|
||||
elif self.tokenizer_type == MistralTokenizerType.tekken:
|
||||
yield from self._tekken_tokens()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
|
||||
|
||||
@staticmethod
|
||||
def token_bytes_to_string(b, byte_encoder):
|
||||
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
|
||||
|
||||
def extract_vocab_merges_from_model(self):
|
||||
# Adapted from Transformers (Apache 2.0)
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
|
||||
assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
|
||||
f"Expected Tekkenizer, got {type(self.tokenizer)}"
|
||||
)
|
||||
mergeable_ranks = self.tokenizer._model._mergeable_ranks
|
||||
token_bytes_map = {
|
||||
rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
|
||||
}
|
||||
merge_pairs = []
|
||||
|
||||
# Sort vocab by rank to ensure correct merge order
|
||||
for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
|
||||
merged_token = token_bytes_map[i]
|
||||
local = []
|
||||
for j in range(1, len(merged_token)):
|
||||
left = merged_token[:j]
|
||||
right = merged_token[j:]
|
||||
if (
|
||||
left in mergeable_ranks
|
||||
and right in mergeable_ranks
|
||||
and (left + right) in mergeable_ranks
|
||||
):
|
||||
local.append((left, right, i))
|
||||
if not local:
|
||||
raise ValueError(
|
||||
f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
|
||||
)
|
||||
local = sorted(
|
||||
local,
|
||||
key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
|
||||
reverse=False,
|
||||
)
|
||||
merge_pairs.extend(local)
|
||||
merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)
|
||||
|
||||
byte_encoder = bytes_to_unicode()
|
||||
|
||||
decoded_merge_pairs = [
|
||||
[
|
||||
self.token_bytes_to_string(val[0], byte_encoder),
|
||||
self.token_bytes_to_string(val[1], byte_encoder),
|
||||
]
|
||||
for val in merge_pairs
|
||||
]
|
||||
|
||||
merges = [
|
||||
" ".join(
|
||||
[
|
||||
# ensure the spaces are properly encoded
|
||||
"".join(chr(ord(c) + 256) if c == " " else c for c in part)
|
||||
for part in pair
|
||||
]
|
||||
)
|
||||
for pair in decoded_merge_pairs
|
||||
]
|
||||
|
||||
return merges
|
||||
|
|
14
klite.embd
14
klite.embd
|
@ -3321,6 +3321,7 @@ Current version indicated by LITEVER below.
|
|||
eos_ban_mode: 0, //allow the EOS token when using locally 0=auto,1=unban,2=ban,3=bypass
|
||||
token_count_multiplier: 100, //100 means 1x
|
||||
opmode: 4, //what mode are we in? 1=story, 2=adventure, 3=chat, 4=instruct
|
||||
adventure_roll_modifier: 0,
|
||||
adventure_switch_mode: 0, //in adventure mode, determine story=0, action=1 or roll=2
|
||||
adventure_context_mod: true, //extra injection for adventure mode
|
||||
fix_alpaca_leak: true, //prevents leaking when Alpaca instruct format is used on crappy models
|
||||
|
@ -12758,6 +12759,7 @@ Current version indicated by LITEVER below.
|
|||
document.getElementById("idle_duration").value = localsettings.idle_duration;
|
||||
document.getElementById("fix_alpaca_leak").checked = localsettings.fix_alpaca_leak;
|
||||
document.getElementById("adventure_context_mod").checked = localsettings.adventure_context_mod;
|
||||
document.getElementById("adventure_roll_modifier").value = localsettings.adventure_roll_modifier;
|
||||
document.getElementById("chat_context_mod").checked = localsettings.chat_context_mod;
|
||||
document.getElementById("instruct_has_markdown").checked = localsettings.instruct_has_markdown;
|
||||
document.getElementById("instruct_has_latex").checked = localsettings.instruct_has_latex;
|
||||
|
@ -13231,6 +13233,7 @@ Current version indicated by LITEVER below.
|
|||
localsettings.idle_duration = document.getElementById("idle_duration").value;
|
||||
localsettings.fix_alpaca_leak = (document.getElementById("fix_alpaca_leak").checked ? true : false);
|
||||
localsettings.adventure_context_mod = (document.getElementById("adventure_context_mod").checked ? true : false);
|
||||
localsettings.adventure_roll_modifier = document.getElementById("adventure_roll_modifier").value;
|
||||
localsettings.chat_context_mod = (document.getElementById("chat_context_mod").checked ? true : false);
|
||||
localsettings.instruct_has_markdown = (document.getElementById("instruct_has_markdown").checked ? true : false);
|
||||
localsettings.instruct_has_latex = (document.getElementById("instruct_has_latex").checked ? true : false);
|
||||
|
@ -15920,8 +15923,12 @@ Current version indicated by LITEVER below.
|
|||
if(localsettings.adventure_switch_mode==2)
|
||||
{
|
||||
let roll = Math.floor(Math.random() * 20) + 1;
|
||||
let modif = parseInt(localsettings.adventure_roll_modifier?localsettings.adventure_roll_modifier:0);
|
||||
let modifstr = (modif>0?`+${modif}`:(modif<0?`${modif}`:""));
|
||||
roll += modif;
|
||||
roll = cleannum(roll,1,20);
|
||||
let outcome = (roll==20?"Perfect":(roll>16?"Excellent":(roll>12?"Good":(roll>8?"Fair":(roll>4?"Poor":"Terrible")))));
|
||||
diceaddon = ` (Rolled 1d20=${roll}/20, Outcome: ${outcome})`;
|
||||
diceaddon = ` (Rolled 1d20${modifstr}=${roll}/20, Outcome: ${outcome})`;
|
||||
}
|
||||
newgen = "\n\n\> " + newgen + diceaddon + "\n\n";
|
||||
}
|
||||
|
@ -24673,6 +24680,11 @@ Current version indicated by LITEVER below.
|
|||
class="helptext">Allows using separate Instruction and Response End Tags, instead of combing them with the start tag. Don't change this halfway through a story!</span></span></div>
|
||||
<input type="checkbox" title="Separate End Tags" id="separate_end_tags" style="margin:0px 0px 0px auto;" onchange="toggle_separate_end_tags()">
|
||||
</div>
|
||||
<div id="idlesection" class="settinglabel">
|
||||
<div class="justifyleft settingsmall" title="Adventure Roll Modifier">Adventure Roll Modifer <span class="helpicon">?<span
|
||||
class="helptext">Adds an integer modifier to adventure mode dice rolls.</span></span></div>
|
||||
<input class="settinglabel miniinput" title="Adventure Roll Modifer" type="text" inputmode="decimal" value="0" id="adventure_roll_modifier" style="height:16px; width:30px; margin:0px 4px 0px auto;">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
|
|
@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
@ -1933,6 +1934,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DREAM,
|
||||
{
|
||||
|
|
|
@ -92,6 +92,7 @@ enum llm_arch {
|
|||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
|
@ -298,7 +298,7 @@ llama_context::llama_context(
|
|||
|
||||
cross.v_embd.clear();
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
// reserve pp (prompt processing) graph first so that buffers are only allocated once
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
|
@ -309,7 +309,7 @@ llama_context::llama_context(
|
|||
n_nodes_pp = ggml_graph_n_nodes(gf);
|
||||
}
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
// reserve with tg (token generation) graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
||||
if (!gf) {
|
||||
|
|
|
@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|||
return moe_out;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * probs,
|
||||
ggml_tensor * up_exps,
|
||||
ggml_tensor * gate_exps,
|
||||
ggml_tensor * down_exps,
|
||||
ggml_tensor * exp_probs_b,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il) const {
|
||||
const int64_t n_embd = cur->ne[0];
|
||||
const int64_t n_tokens = cur->ne[1];
|
||||
|
||||
// add experts selection bias - introduced in DeepSeek V3
|
||||
// leave probs unbiased as it's later used to get expert weights
|
||||
ggml_tensor * selection_probs = probs;
|
||||
if (exp_probs_b != nullptr) {
|
||||
selection_probs = ggml_add(ctx0, probs, exp_probs_b);
|
||||
cb(selection_probs, "ffn_moe_probs_biased", il);
|
||||
}
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
cb(selected_experts, "ffn_moe_topk", il);
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx0,
|
||||
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights", il);
|
||||
|
||||
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
||||
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
|
||||
weights = ggml_soft_max(ctx0, weights);
|
||||
} else {
|
||||
weights = ggml_sigmoid(ctx0, weights);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
|
||||
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||
|
||||
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights_norm", il);
|
||||
}
|
||||
|
||||
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
||||
|
||||
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
|
||||
|
||||
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(up, "ffn_moe_up", il);
|
||||
|
||||
ggml_tensor * experts = nullptr;
|
||||
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(cur, "ffn_moe_gate", il);
|
||||
|
||||
cur = ggml_reglu_split(ctx0, cur, up);
|
||||
cb(cur, "ffn_moe_reglu", il);
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
experts = ggml_mul(ctx0, experts, weights);
|
||||
cb(cur, "ffn_moe_weighted", il);
|
||||
|
||||
ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
|
||||
|
||||
assert(n_expert_used > 0);
|
||||
|
||||
// order the views before the adds
|
||||
for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
|
||||
cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
|
||||
|
||||
ggml_build_forward_expand(gf, cur_experts[i]);
|
||||
}
|
||||
|
||||
// aggregate experts
|
||||
// note: here we explicitly use hparams.n_expert_used instead of n_expert_used
|
||||
// to avoid potentially a large number of add nodes during warmup
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14753
|
||||
ggml_tensor * moe_out = cur_experts[0];
|
||||
|
||||
for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
|
||||
moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
|
||||
}
|
||||
|
||||
if (n_expert_used == 1) {
|
||||
// avoid returning a non-contiguous tensor
|
||||
moe_out = ggml_cont(ctx0, moe_out);
|
||||
}
|
||||
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
return moe_out;
|
||||
}
|
||||
|
||||
// input embeddings with optional lora
|
||||
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
|
|
@ -625,6 +625,18 @@ struct llm_graph_context {
|
|||
llama_expert_gating_func_type gating_op,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_moe_ffn_from_probs(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * probs,
|
||||
ggml_tensor * up_exps,
|
||||
ggml_tensor * gate_exps,
|
||||
ggml_tensor * down_exps,
|
||||
ggml_tensor * exp_probs_b,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il) const;
|
||||
|
||||
//
|
||||
// inputs
|
||||
//
|
||||
|
|
|
@ -2,9 +2,15 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
|
||||
if (dense_first) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
|
||||
}
|
||||
} else {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -140,7 +140,7 @@ struct llama_hparams {
|
|||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
// llama4
|
||||
// llama4 smallthinker
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 8192;
|
||||
|
@ -161,9 +161,10 @@ struct llama_hparams {
|
|||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
|
||||
// dense_first means whether the pattern is start with a dense layer
|
||||
// note that if n_pattern == 0, all layers are SWA
|
||||
// if n_pattern == 1, all layers are dense
|
||||
// example: n_pattern = 3
|
||||
// example 1: n_pattern = 3, dense_first = false
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
|
@ -172,7 +173,13 @@ struct llama_hparams {
|
|||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
void set_swa_pattern(uint32_t n_pattern);
|
||||
// example 2: n_pattern = 2, dense_first = true
|
||||
// il == 0: dense
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// etc ...
|
||||
void set_swa_pattern(uint32_t n_pattern, bool dense_first = false);
|
||||
|
||||
// return true if one of the layers is SWA
|
||||
bool is_swa_any() const;
|
||||
|
|
|
@ -1773,6 +1773,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
{
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
||||
if (found_swa && hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa = 4096;
|
||||
hparams.set_swa_pattern(4, true);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
hparams.n_no_rope_layer_step = hparams.n_layer;
|
||||
}
|
||||
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_4B; break;
|
||||
case 52: type = LLM_TYPE_20B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
@ -5261,6 +5284,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
|
||||
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER");
|
||||
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER");
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
@ -5587,6 +5646,11 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_SMALLTHINKER) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
}
|
||||
|
||||
vocab.print_info();
|
||||
}
|
||||
|
||||
|
@ -17111,6 +17175,119 @@ struct llm_build_lfm2 : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
template <bool iswa>
|
||||
struct llm_build_smallthinker : public llm_graph_context{
|
||||
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
|
||||
inp_attn_type * inp_attn = nullptr;
|
||||
|
||||
if constexpr (iswa) {
|
||||
inp_attn = build_attn_inp_kv_unified_iswa();
|
||||
} else {
|
||||
inp_attn = build_attn_inp_kv_unified();
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
ggml_tensor * probs = nullptr;
|
||||
|
||||
probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens]
|
||||
cb(probs, "ffn_moe_logits", il);
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
|
||||
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
|
||||
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
probs = ggml_get_rows(ctx0, probs, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
|
||||
nullptr, n_expert, n_expert_used,
|
||||
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
|
||||
|
||||
cb(ffn_out, "ffn_out", il);
|
||||
cur = ffn_out;
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
|
@ -17549,6 +17726,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_lfm2>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
{
|
||||
if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
|
||||
llm = std::make_unique<llm_build_smallthinker<true>> (*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
@ -17747,6 +17932,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
@ -131,6 +131,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
|
||||
PROJECTOR_TYPE_VOXTRAL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -150,6 +151,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
|
||||
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
|
|
@ -372,6 +372,16 @@ struct clip_model {
|
|||
ggml_tensor * conv1d_2_b = nullptr;
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
}
|
||||
|
||||
bool audio_has_stack_frames() const {
|
||||
return proj_type == PROJECTOR_TYPE_ULTRAVOX
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
}
|
||||
};
|
||||
|
||||
bool enable_gpu_clip = true;
|
||||
|
@ -1508,49 +1518,52 @@ struct clip_graph {
|
|||
|
||||
cb(cur, "after_transformer", -1);
|
||||
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
if (model.audio_has_stack_frames()) {
|
||||
// StackAudioFrames
|
||||
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
|
||||
{
|
||||
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||
int64_t pad = padded_len - ggml_nelements(cur);
|
||||
if (pad > 0) {
|
||||
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||
}
|
||||
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||
ggml_row_size(cur->type, stride), 0);
|
||||
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||
int64_t pad = padded_len - ggml_nelements(cur);
|
||||
if (pad > 0) {
|
||||
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||
}
|
||||
|
||||
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||
ggml_row_size(cur->type, stride), 0);
|
||||
cb(cur, "after_stacked", -1);
|
||||
}
|
||||
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
// UltravoxProjector
|
||||
{
|
||||
// pre-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
// pre-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
|
||||
// ffn in
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
// ffn in
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
|
||||
// swiglu
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
cur = ggml_swiglu_swapped(ctx0, cur);
|
||||
// swiglu
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
cur = ggml_swiglu_swapped(ctx0, cur);
|
||||
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||
|
||||
// ffn out
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
}
|
||||
// ffn out
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
|
||||
// projector
|
||||
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_fc_b);
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
|
||||
// projector
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
cur = ggml_gelu_erf(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("%s: unknown projector type", __func__);
|
||||
}
|
||||
|
@ -1695,8 +1708,7 @@ private:
|
|||
inpL = cur;
|
||||
}
|
||||
|
||||
// TODO @ngxson : find a way to move this outside
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
|
||||
if (ctx->model.audio_has_avgpool()) {
|
||||
ggml_tensor * cur = inpL;
|
||||
cur = ggml_transpose(ctx0, cur);
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
|
@ -2010,6 +2022,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
res = graph.build_llama4();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
res = graph.build_whisper_enc();
|
||||
|
@ -2310,8 +2323,10 @@ struct clip_model_loader {
|
|||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
|
||||
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
|
||||
model.proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
|
||||
if (hparams.n_mel_bins != 128) {
|
||||
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
|
||||
|
@ -2595,6 +2610,15 @@ struct clip_model_loader {
|
|||
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
{
|
||||
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
|
||||
|
@ -3746,17 +3770,26 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
|||
int scale_factor = ctx->model.hparams.proj_scale_factor;
|
||||
n_patches_sq /= (scale_factor * scale_factor);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
|
||||
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
|
||||
n_patches_sq = n_len / proj_stack_factor / 2;
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
// divide by 2 because of whisper
|
||||
// another divide by 2 because of nn.AvgPool1d(2, stride=2)
|
||||
n_patches_sq = img->nx / 4;
|
||||
n_patches_sq = img->nx;
|
||||
|
||||
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
|
||||
if (ctx->model.audio_has_stack_frames()) {
|
||||
GGML_ASSERT(proj_stack_factor > 0);
|
||||
const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor);
|
||||
n_patches_sq = n_len / proj_stack_factor;
|
||||
}
|
||||
|
||||
// whisper downscales input token by half after conv1d
|
||||
n_patches_sq /= 2;
|
||||
|
||||
if (ctx->model.audio_has_avgpool()) {
|
||||
// divide by 2 because of nn.AvgPool1d(2, stride=2)
|
||||
n_patches_sq /= 2;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("unsupported projector type");
|
||||
|
@ -4162,6 +4195,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
{
|
||||
// do nothing
|
||||
} break;
|
||||
|
@ -4442,6 +4476,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
return ctx->model.projection->ne[1];
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_VOXTRAL:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
return ctx->model.mm_3_w->ne[1];
|
||||
|
@ -4492,7 +4527,8 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
|
|||
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
|
||||
}
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
|
|
|
@ -289,6 +289,10 @@ struct mtmd_context {
|
|||
aud_beg = "<|audio_bos|>";
|
||||
aud_end = "<|audio_eos|>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
// [BEGIN_AUDIO] ... (embeddings) ...
|
||||
aud_beg = "[BEGIN_AUDIO]";
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-r ../../requirements/requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
pillow~=10.2.0
|
||||
pillow~=11.3.0
|
||||
torch~=2.2.1
|
||||
torchvision~=0.17.1
|
||||
|
|
|
@ -71,6 +71,7 @@ add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
|||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
|
||||
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
if [ "$RUN_BIG_TESTS" = true ]; then
|
||||
|
|
|
@ -1,18 +1,25 @@
|
|||
# quantize
|
||||
|
||||
This tool takes a GGUF input model file, typically in a high-precision format like F32 or BF16, and converts it to a quantized format.
|
||||
Quantization reduces the precision of model weights (e.g., from 32-bit floats to 4-bit integers), which shrinks the model's size and can speed up inference.
|
||||
This process however, may introduce some accuracy loss which is usually measured in [Perplexity](https://huggingface.co/docs/transformers/en/perplexity) (ppl) and/or [Kullback–Leibler Divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) (kld).
|
||||
This can be minimized by using a suitable imatrix file.
|
||||
|
||||
You can also use the [GGUF-my-repo](https://huggingface.co/spaces/ggml-org/gguf-my-repo) space on Hugging Face to build your own quants without any setup.
|
||||
|
||||
Note: It is synced from llama.cpp `main` every 6 hours.
|
||||
|
||||
Example usage:
|
||||
|
||||
```./llama-quantize [options] input-model-f32.gguf [output-model-quant.gguf] type [threads]```
|
||||
|
||||
```bash
|
||||
# obtain the official LLaMA model weights and place them in ./models
|
||||
# from Hugginface, obtain the official meta-llama/Llama-3.1-8B model weights and place them in ./models
|
||||
ls ./models
|
||||
llama-2-7b tokenizer_checklist.chk tokenizer.model
|
||||
# [Optional] for models using BPE tokenizers
|
||||
ls ./models
|
||||
<folder containing weights and tokenizer json> vocab.json
|
||||
config.json model-00001-of-00004.safetensors model-00004-of-00004.safetensors README.md tokenizer.json
|
||||
generation_config.json model-00002-of-00004.safetensors model.safetensors.index.json special_tokens_map.json USE_POLICY.md
|
||||
LICENSE model-00003-of-00004.safetensors original tokenizer_config.json
|
||||
|
||||
# [Optional] for PyTorch .bin models like Mistral-7B
|
||||
ls ./models
|
||||
<folder containing weights and tokenizer json>
|
||||
|
@ -21,7 +28,7 @@ ls ./models
|
|||
python3 -m pip install -r requirements.txt
|
||||
|
||||
# convert the model to ggml FP16 format
|
||||
python3 convert_hf_to_gguf.py models/mymodel/
|
||||
python3 convert_hf_to_gguf.py ./models/mymodel/
|
||||
|
||||
# quantize the model to 4-bits (using Q4_K_M method)
|
||||
./llama-quantize ./models/mymodel/ggml-model-f16.gguf ./models/mymodel/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
|
@ -37,40 +44,117 @@ Run the quantized model:
|
|||
./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant"
|
||||
```
|
||||
|
||||
When running the larger models, make sure you have enough disk space to store all the intermediate files.
|
||||
Options:
|
||||
* `--allow-requantize` allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit
|
||||
* `--leave-output-tensor` will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing
|
||||
* `--pure` disables k-quant mixtures and quantizes all tensors to the same type
|
||||
* `--imatrix` uses data in file generated by `llama-imatrix` as importance matrix for quant optimizations (highly recommended)
|
||||
* `--include-weights` use an importance matrix for tensor(s) in the list. Cannot be used with `--exclude-weights`
|
||||
* `--exclude-weights` use an importance matrix for tensor(s) in the list. Cannot be used with `--include-weights`
|
||||
* `--output-tensor-type` use a specific quant type for the output.weight tensor
|
||||
* `--token-embedding-type` use a specific quant type for the token embeddings tensor
|
||||
* `--keep-split` will generate the quantized model in the same shards as the input file otherwise it will produce a single quantized file
|
||||
|
||||
Advanced options:
|
||||
* `--tensor-type` quantize specific tensor(s) to specific quant types. Supports regex syntax. May be specified multiple times.
|
||||
* `--prune-layers` prune (remove) the layers in the list
|
||||
* `--override-kv` option to override model metadata by key in the quantized model. May be specified multiple times
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# naive Q4_K_M quantization using default settings and 8 CPU threads. Output will be "ggml-model-Q4_K_M.gguf"
|
||||
./llama-quantize input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# quantize model enabling re-quantization, leaving the output tensor unquantized and all others quantized at the same level (Q4_K)
|
||||
./llama-quantize --allow-requantize --leave-output-tensor --pure input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# quantize model using an importance matrix for specified tensors only (attn_v and ffn_down)
|
||||
./llama-quantize --imatrix imatrix.gguf --include-weights attn_v --include-weights ffn_down input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# quantize model setting output tensor to Q5_K_M, token embeddings to Q3_K_M, and keeping the input file's shards
|
||||
./llama-quantize --imatrix imatrix.gguf --output-tensor-type q5_k --token-embedding-type q3_k --keep-split input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# quantize model using a regex to quantize attn_k tensors in odd layers to Q5_K_M and attn_q tensors in even layers to Q3_K_M
|
||||
./llama-quantize --imatrix imatrix.gguf --tensor-type "\.(\d*[13579])\.attn_k=q5_k" --tensor-type "\.(\d*[02468])\.attn_q=q3_k" input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# quantize model setting tensors attn_v and ffn_down to Q5_K_M and pruning layers 20, 21, and 22
|
||||
./llama-quantize --imatrix imatrix.gguf --tensor-type attn_v=q5_k --tensor-type ffn_down=q5_k --prune-layers 20,21,22 input-model-f32.gguf q4_k_m 8
|
||||
```
|
||||
|
||||
```bash
|
||||
# override expert used count metadata to 16, prune layers 20, 21, and 22 without quantizing the model (copy tensors) and use specified name for the output file
|
||||
./llama-quantize --imatrix imatrix.gguf --override-kv qwen3moe.expert_used_count=int:16 --prune-layers 20,21,22 input-model-f32.gguf pruned-model-f32.gguf copy 8
|
||||
```
|
||||
|
||||
## Memory/Disk Requirements
|
||||
|
||||
As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same.
|
||||
When running the larger models, make sure you have enough disk space to store all the intermediate files.
|
||||
As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For exmaple (Llama 3.1):
|
||||
|
||||
| Model | Original size | Quantized size (Q4_K_M) |
|
||||
| ----: | ------------: | ----------------------: |
|
||||
| 8B | 32.1 GB | 4.9 GB |
|
||||
| 70B | 280.9 GB | 43.1 GB |
|
||||
| 405B | 1,625.1 GB | 249.1 GB |
|
||||
|
||||
| Model | Original size | Quantized size (Q4_0) |
|
||||
|------:|--------------:|----------------------:|
|
||||
| 7B | 13 GB | 3.9 GB |
|
||||
| 13B | 24 GB | 7.8 GB |
|
||||
| 30B | 60 GB | 19.5 GB |
|
||||
| 65B | 120 GB | 38.5 GB |
|
||||
|
||||
## Quantization
|
||||
|
||||
Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
|
||||
Several quantization methods are supported. They differ in the resulting model disk size and inference speed. For example,
|
||||
|
||||
*(outdated)*
|
||||
### [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)
|
||||
|
||||
| Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 |
|
||||
|------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|
|
||||
| 7B | perplexity | 5.9066 | 6.1565 | 6.0912 | 5.9862 | 5.9481 | 5.9070 |
|
||||
| 7B | file size | 13.0G | 3.5G | 3.9G | 4.3G | 4.7G | 6.7G |
|
||||
| 7B | ms/tok @ 4th | 127 | 55 | 54 | 76 | 83 | 72 |
|
||||
| 7B | ms/tok @ 8th | 122 | 43 | 45 | 52 | 56 | 67 |
|
||||
| 7B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 |
|
||||
| 13B | perplexity | 5.2543 | 5.3860 | 5.3608 | 5.2856 | 5.2706 | 5.2548 |
|
||||
| 13B | file size | 25.0G | 6.8G | 7.6G | 8.3G | 9.1G | 13G |
|
||||
| 13B | ms/tok @ 4th | - | 103 | 105 | 148 | 160 | 131 |
|
||||
| 13B | ms/tok @ 8th | - | 73 | 82 | 98 | 105 | 128 |
|
||||
| 13B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 |
|
||||
| Measure | IQ1_S | IQ1_M | IQ2_XXS | IQ2_XS | IQ2_S | IQ2_M |
|
||||
| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
|
||||
| bits/weight | 2.0042 | 2.1460 | 2.3824 | 2.5882 | 2.7403 | 2.9294 |
|
||||
| size (GiB) | 1.87 | 2.01 | 2.23 | 2.42 | 2.56 | 2.74 |
|
||||
| prompt processing t/s @ 512 | 858.88 ±1.22 | 847.99 ±0.47 | 852.39 ±0.85 | 826.99 ±12.51 | 783.55 ±13.73 | 787.68 ±7.00 |
|
||||
| text generation t/s @ 128 | 79.73 ±0.79 | 72.92 ±0.14 | 79.86 ±0.22 | 78.04 ±0.46 | 77.30 ±2.47 | 74.44 ±0.15 |
|
||||
|
||||
| Measure | IQ3_XXS | IQ3_XS | IQ3_S | IQ3_M | IQ4_XS | IQ4_NL |
|
||||
| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
|
||||
| bits/weight | 3.2548 | 3.4977 | 3.6606 | 3.7628 | 4.4597 | 4.6818 |
|
||||
| size (GiB) | 3.04 | 3.27 | 3.42 | 3.52 | 4.17 | 4.38 |
|
||||
| prompt processing t/s @ 512 | 813.88 ±6.53 | 708.71 ±1.26 | 798.78 ±8.81 | 768.70 ±13.73 | 771.80 ±11.38 | 806.03 ±7.07 |
|
||||
| text generation t/s @ 128 | 73.95 ±0.20 | 71.67 ±0.54 | 69.31 ±0.63 | 70.15 ±0.33 | 77.51 ±0.20 | 76.63 ±0.28 |
|
||||
|
||||
|
||||
| Measure | Q2_K_S | Q2_K | Q3_K_S | Q3_K_M | Q3_K_L | Q4_K_S |
|
||||
| --------------------------- | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ |
|
||||
| bits/weight | 2.9697 | 3.1593 | 3.6429 | 3.9960 | 4.2979 | 4.6672 |
|
||||
| size (GiB) | 2.78 | 2.95 | 3.41 | 3.74 | 4.02 | 4.36 |
|
||||
| prompt processing t/s @ 512 | 798.91 ±6.40 | 784.45 ±7.85 | 752.17 ±7.94 | 783.44 ±9.92 | 761.17 ±7.55 | 818.55 ±9.58 |
|
||||
| text generation t/s @ 128 | 90.01 ±0.12 | 79.85 ±0.20 | 69.84 ±0.18 | 71.68 ±0.22 | 69.38 ±0.49 | 76.71 ±0.20 |
|
||||
|
||||
| Measure | Q4_K_S | Q4_K_M | Q5_K_S | Q5_K_M | Q6_K | Q8_0 |
|
||||
| --------------------------- | ------------ | ------------- | ------------ | ------------ | ------------- | ------------ |
|
||||
| bits/weight | 4.6672 | 4.8944 | 5.5704 | 5.7036 | 6.5633 | 8.5008 |
|
||||
| size (GiB) | 4.36 | 4.58 | 5.21 | 5.33 | 6.14 | 7.95 |
|
||||
| prompt processing t/s @ 512 | 818.55 ±9.58 | 821.81 ±21.44 | 752.52 ±0.99 | 758.69 ±7.43 | 812.01 ±10.82 | 865.09 ±8.30 |
|
||||
| text generation t/s @ 128 | 76.71 ±0.20 | 71.93 ±1.52 | 69.53 ±0.18 | 67.23 ±1.08 | 58.67 ±3.13 | 50.93 ±0.08 |
|
||||
|
||||
| Measure | F16 |
|
||||
| --------------------------- | ------------ |
|
||||
| bits/weight | 16.0005 |
|
||||
| size (GiB) | 14.96 |
|
||||
| prompt processing t/s @ 512 | 923.49 ±0.53 |
|
||||
| text generation t/s @ 128 | 29.17 ±0.04 |
|
||||
|
||||
## Background information on llama-quantize
|
||||
|
||||
- [k-quants](https://github.com/ggml-org/llama.cpp/pull/1684)
|
||||
- recent k-quants improvements and new i-quants
|
||||
- k-quants improvements and i-quants
|
||||
- [#2707](https://github.com/ggml-org/llama.cpp/pull/2707)
|
||||
- [#2807](https://github.com/ggml-org/llama.cpp/pull/2807)
|
||||
- [#4773 - 2-bit i-quants (inference)](https://github.com/ggml-org/llama.cpp/pull/4773)
|
||||
|
@ -85,45 +169,3 @@ Several quantization methods are supported. They differ in the resulting model d
|
|||
- [#5060 - Q3_K_XS](https://github.com/ggml-org/llama.cpp/pull/5060)
|
||||
- [#5196 - 3-bit i-quants](https://github.com/ggml-org/llama.cpp/pull/5196)
|
||||
- [quantization tuning](https://github.com/ggml-org/llama.cpp/pull/5320), [another one](https://github.com/ggml-org/llama.cpp/pull/5334), and [another one](https://github.com/ggml-org/llama.cpp/pull/5361)
|
||||
|
||||
**Llama 2 7B**
|
||||
|
||||
| Quantization | Bits per Weight (BPW) |
|
||||
|--------------|-----------------------|
|
||||
| Q2_K | 3.35 |
|
||||
| Q3_K_S | 3.50 |
|
||||
| Q3_K_M | 3.91 |
|
||||
| Q3_K_L | 4.27 |
|
||||
| Q4_K_S | 4.58 |
|
||||
| Q4_K_M | 4.84 |
|
||||
| Q5_K_S | 5.52 |
|
||||
| Q5_K_M | 5.68 |
|
||||
| Q6_K | 6.56 |
|
||||
|
||||
**Llama 2 13B**
|
||||
|
||||
Quantization | Bits per Weight (BPW)
|
||||
-- | --
|
||||
Q2_K | 3.34
|
||||
Q3_K_S | 3.48
|
||||
Q3_K_M | 3.89
|
||||
Q3_K_L | 4.26
|
||||
Q4_K_S | 4.56
|
||||
Q4_K_M | 4.83
|
||||
Q5_K_S | 5.51
|
||||
Q5_K_M | 5.67
|
||||
Q6_K | 6.56
|
||||
|
||||
**Llama 2 70B**
|
||||
|
||||
Quantization | Bits per Weight (BPW)
|
||||
-- | --
|
||||
Q2_K | 3.40
|
||||
Q3_K_S | 3.47
|
||||
Q3_K_M | 3.85
|
||||
Q3_K_L | 4.19
|
||||
Q4_K_S | 4.53
|
||||
Q4_K_M | 4.80
|
||||
Q5_K_S | 5.50
|
||||
Q5_K_M | 5.65
|
||||
Q6_K | 6.56
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue