merged pixtral support, not fully working

This commit is contained in:
Concedo 2025-04-24 15:27:02 +08:00
commit 28a2723100
24 changed files with 1279 additions and 512 deletions

View file

@ -776,6 +776,9 @@ class TextModel(ModelBase):
if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
# ref: https://huggingface.co/mistral-community/pixtral-12b
res = "pixtral"
if res is None:
logger.warning("\n")
@ -1724,7 +1727,8 @@ class StableLMModel(TextModel):
"MistralForCausalLM",
"MixtralForCausalLM",
"Idefics3ForConditionalGeneration",
"SmolVLMForConditionalGeneration")
"SmolVLMForConditionalGeneration",
"LlavaForConditionalGeneration")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
@ -1734,6 +1738,10 @@ class LlamaModel(TextModel):
# fix for SmolVLM2, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
# fix for Pixtral, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
and self.hparams.get("model_type") == "mistral":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def set_vocab(self):
try:
@ -1797,12 +1805,17 @@ class LlamaModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
is_vision_tensor = "vision_tower" in name \
or "vision_model" in name \
or "model.connector" in name \
or "multi_modal_projector" in name
if is_vision_tensor:
return [] # skip vision tensors
elif name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
elif name.startswith("language_model."):
name = name.replace("language_model.", "") # for the rest
if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
@ -1885,6 +1898,55 @@ class LlamaModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("LlavaForConditionalGeneration")
class LlavaVisionModel(VisionModel):
img_break_tok_id = -1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "pixtral":
# fix missing config.json values
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
# default values below are taken from HF tranformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
n_head = self.hparams["num_attention_heads"]
n_kv_head = n_head
if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
# process vision tensors
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
# for pixtral model, we need to extract the [IMG_BREAK] token embedding
img_break_embd = data_torch[self.img_break_tok_id]
name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
return [(self.map_tensor_name(name), img_break_embd)]
return [] # skip other tensors
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
def __init__(self, *args, **kwargs):
@ -5079,10 +5141,25 @@ class Glm4Model(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4
def set_vocab(self):
self._set_vocab_gpt2()
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
super().set_gguf_parameters()
rope_dim = self.hparams["head_dim"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)

View file

@ -115,6 +115,7 @@ models = [
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
]

View file

@ -11,15 +11,15 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
cmake --build build --target llama-mtmd-cli
# alternatively, install from brew (MacOS)
brew install llama.cpp
# run it
llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
# note: 1B model does not support vision
```
@ -44,8 +44,8 @@ What you need:
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
cmake --build build --target llama-mtmd-cli
# run it
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
```

View file

@ -64,6 +64,7 @@
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
@ -78,6 +79,7 @@
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
// mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@ -106,6 +108,7 @@ enum projector_type {
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_UNKNOWN,
};
@ -118,6 +121,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View file

@ -177,7 +177,8 @@ struct clip_hparams {
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
float eps;
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
@ -203,11 +204,17 @@ struct clip_layer {
struct ggml_tensor * ln_1_b = nullptr;
// ff
struct ggml_tensor * ff_i_w = nullptr;
struct ggml_tensor * ff_i_b = nullptr;
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr;
struct ggml_tensor * ff_o_b = nullptr;
struct ggml_tensor * ff_up_w = nullptr;
struct ggml_tensor * ff_up_b = nullptr;
struct ggml_tensor * ff_gate_w = nullptr;
struct ggml_tensor * ff_gate_b = nullptr;
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
struct ggml_tensor * ff_g_w = NULL;
struct ggml_tensor * ff_g_b = NULL;
@ -316,6 +323,9 @@ struct clip_vision_model {
// gemma3
struct ggml_tensor * mm_input_proj_w = nullptr;
struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
};
bool enable_gpu_clip = true;
@ -357,6 +367,7 @@ struct clip_ctx {
ggml_backend_buffer_ptr buf;
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_image_size load_image_size;
@ -577,6 +588,218 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
return gf;
}
// implementation of the 2D RoPE without adding a new op in ggml
static ggml_tensor * build_rope_2d(
ggml_cgraph * gf,
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * pos_h,
ggml_tensor * pos_w,
const float freq_base
) {
ggml_tensor * tmp;
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];
// for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
// we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
// first half of cur will use 1e-0, 1e-2 (even)
// second half of cur will use 1e-1, 1e-3 (odd)
//
// for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
// first half
{
cur = ggml_rope_ext_inplace(
ctx0,
cur,
pos_h, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);
}
// second half
{
tmp = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
n_dim/2 * ggml_element_size(cur));
tmp = ggml_rope_ext_inplace(
ctx0,
tmp,
pos_w, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
freq_scale_odd,
0.0f, 1.0f, 0.0f, 0.0f
);
// calculate inplace (modify cur directly)
ggml_build_forward_expand(gf, tmp);
}
return cur;
}
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
int image_size_width = imgs.entries[0]->nx;
int image_size_height = imgs.entries[0]->ny;
const int patch_size = hparams.patch_size;
const int n_patches_x = image_size_width / patch_size;
const int n_patches_y = image_size_height / patch_size;
const int num_patches = n_patches_x * n_patches_y;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx0_ptr(ggml_init(params));
auto ctx0 = ctx0_ptr.get();
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
// 2D input positions
struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
struct ggml_tensor * embeddings = inp;
// pre-layer norm
embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings;
// pre-attention norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
// self-attention
{
struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
}
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// pre-ffn norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
// feed-forward
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// LlavaMultiModalProjector (with GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
// arrangement of the [IMG_BREAK] token
{
// not efficient, but works
// the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1);
embeddings = ggml_view_2d(ctx0, cur,
n_embd_text, n_tokens_output,
ggml_row_size(cur->type, n_embd_text), 0);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
return gf;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
@ -1234,6 +1457,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = clip_image_build_graph_siglip(ctx, imgs);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
res = clip_image_build_graph_pixtral(ctx, imgs);
} break;
default:
{
// TODO: we should have one build_* function per model
@ -1402,6 +1629,10 @@ struct clip_model_loader {
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
} break;
default:
break;
}
@ -1473,18 +1704,28 @@ struct clip_model_loader {
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_g_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), ctx_clip.use_glu_mlp);
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), !ctx_clip.use_rms_norm);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), !ctx_clip.use_rms_norm);
layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_g_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), ctx_clip.use_glu_mlp);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
// new naming
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
// legacy naming (the in and out is reversed! don't ask me why)
layer.ff_i_w = layer.ff_down_w;
layer.ff_o_w = layer.ff_up_w;
layer.ff_i_b = layer.ff_down_b;
layer.ff_o_b = layer.ff_up_b;
layer.ff_g_w = layer.ff_gate_w;
layer.ff_g_b = layer.ff_gate_b;
}
switch (ctx_clip.proj_type) {
@ -1600,6 +1841,15 @@ struct clip_model_loader {
{
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@ -1642,18 +1892,17 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
clip_image_size image_size;
image_size.width = clip_get_image_size(&ctx_clip);
image_size.height = clip_get_image_size(&ctx_clip);
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
img->nx = n_patches;
img->ny = n_patches;
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
image_size.width = ctx_clip.vision_model.hparams.image_size;
image_size.height = ctx_clip.vision_model.hparams.image_size;
img->nx = image_size.width;
img->ny = image_size.height;
img->buf.resize(image_size.width * image_size.height * 3);
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
@ -2128,6 +2377,26 @@ struct image_manipulation {
}
}
// calculate the size of the **resized** image, while preserving the aspect ratio
// the calculated size will be aligned to the nearest multiple of align_size
// if H or W size is larger than max_dimension, it will be resized to max_dimension
static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
return {0, 0};
}
float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
static_cast<float>(max_dimension) / inp_size.height));
float target_width_f = static_cast<float>(inp_size.width) * scale;
float target_height_f = static_cast<float>(inp_size.height) * scale;
int aligned_width = GGML_PAD((int)target_width_f, align_size);
int aligned_height = GGML_PAD((int)target_height_f, align_size);
return {aligned_width, aligned_height};
}
private:
static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper));
@ -2459,8 +2728,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
if (ctx->has_glm_projector
else if (ctx->has_glm_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
@ -2472,6 +2740,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
@ -2603,6 +2880,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
n_patches = 256;
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size;
int n_patches_y = img->ny / params.patch_size;
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
}
return n_patches;
@ -2756,10 +3037,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
float * data = (float *)malloc(ggml_nbytes(inp_raw));
// TODO @ngxson : this whole code block is ugly, will need to be refactored
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
if (ctx->has_glm_projector
|| ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
@ -2937,6 +3223,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
// do nothing
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
std::vector<int> pos_data(num_positions);
struct ggml_tensor * pos;
// dimension H
pos = ggml_graph_get_tensor(gf, "pos_h");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i / n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
// dimension W
pos = ggml_graph_get_tensor(gf, "pos_w");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i % n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@ -3204,6 +3508,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_LDPV2:
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];

View file

@ -24,7 +24,9 @@
#include <signal.h>
#endif
static bool g_is_generating = false;
// volatile, because of signal being an interrupt
static volatile bool g_is_generating = false;
static volatile bool g_is_interrupted = false;
/**
* Please note that this is NOT a production-ready stuff.
@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
g_is_generating = false;
} else {
console::cleanup();
LOG("\nInterrupted by user\n");
_exit(130);
if (g_is_interrupted) {
_exit(1);
}
g_is_interrupted = true;
}
}
}
@ -167,7 +171,7 @@ struct decode_embd_batch {
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating) {
if (i > n_predict || !g_is_generating || g_is_interrupted) {
printf("\n");
break;
}
@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);
if (g_is_interrupted) {
printf("\n");
break;
}
// eval the token
common_batch_clear(ctx.batch);
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
text.add_special = add_bos;
text.parse_special = true;
mtmd_input_chunks chunks;
if (g_is_interrupted) return 0;
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
if (res != 0) {
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
#endif
}
if (g_is_interrupted) return 130;
if (is_single_turn) {
g_is_generating = true;
if (params.prompt.find("<__image__>") == std::string::npos) {
@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
if (eval_message(ctx, msg, params.image, true)) {
return 1;
}
if (generate_response(ctx, smpl, n_predict)) {
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
return 1;
}
@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
std::vector<std::string> images_fname;
std::string content;
while (true) {
while (!g_is_interrupted) {
g_is_generating = false;
LOG("\n> ");
console::set_display(console::user_input);
std::string line;
console::readline(line, false);
if (g_is_interrupted) break;
console::set_display(console::reset);
line = string_strip(line);
if (line.empty()) {
@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
msg.role = "user";
msg.content = content;
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
if (g_is_interrupted) break;
if (ret == 2) {
// non-fatal error
images_fname.clear();
@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
is_first_msg = false;
}
}
if (g_is_interrupted) LOG("\nInterrupted by user\n");
llama_perf_context_print(ctx.lctx);
return 0;
return g_is_interrupted ? 130 : 0;
}

View file

@ -190,6 +190,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
}
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
@ -219,7 +224,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
for (auto & entry : batch_f32.entries) {
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_patches(ctx->ctx_clip);
image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
image_tokens->ny = 1;
image_tokens->batch_f32.entries.push_back(std::move(entry));
image_tokens->id = id;
@ -313,8 +318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
}
} else {
size_t n_tokens = 0;
for (const auto & entry : batch_f32.entries) {
n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
}
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
image_tokens->nx = n_tokens;
image_tokens->ny = 1; // TODO
image_tokens->batch_f32 = std::move(batch_f32);
image_tokens->id = bitmaps[i_img].id; // optional
@ -382,7 +392,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) {
int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
ok = clip_image_encode(
ctx->ctx_clip,
ctx->n_threads,

View file

@ -13,6 +13,14 @@ mkdir -p $SCRIPT_DIR/output
PROJ_ROOT="$SCRIPT_DIR/../.."
cd $PROJ_ROOT
# Check if the first argument is "big", then run test with big models
# This is useful if we're running the script on a larger machine, so we can test the big models
RUN_BIG_TESTS=false
if [ "${1:-}" = "big" ]; then
RUN_BIG_TESTS=true
echo "Include BIG models..."
fi
###############
arr_bin=()
@ -28,6 +36,12 @@ add_test() {
arr_tmpl+=("$tmpl")
}
add_test_big() {
if [ "$RUN_BIG_TESTS" = true ]; then
add_test "$@"
fi
}
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@ -42,6 +56,9 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
# these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"

View file

@ -1411,6 +1411,11 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
// const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
const int64_t nb13 = src1->nb[3];
const int64_t nb2 = dst->nb[2];
const int64_t nb3 = dst->nb[3];
@ -1546,7 +1551,10 @@ static void ggml_cuda_op_mul_mat(
dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
if (src1_on_device && src1_is_contiguous) {
quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
quantize_src1(
dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
src1_padded_col_size, ne11, ne12, ne13, stream);
CUDA_CHECK(cudaGetLastError());
}
}
@ -1641,7 +1649,9 @@ static void ggml_cuda_op_mul_mat(
}
if (quantize_src1 && !src1_is_contiguous) {
quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
quantize_src1(
src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
src1_padded_col_size, src1_ncols, 1, 1, stream);
CUDA_CHECK(cudaGetLastError());
}
@ -1879,7 +1889,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@ -1920,10 +1930,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
} else if (!split && use_mul_mat_vec_q) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
@ -2000,6 +2012,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
GGML_TENSOR_BINARY_OP_LOCALS
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
} else {
ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
}
return;
}
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
@ -2036,27 +2057,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@ -2128,7 +2128,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
CUDA_CHECK(cudaGetLastError());
}
}
}
}
void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
@ -2494,7 +2493,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID) {
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
@ -3208,9 +3207,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK: {
const size_t ts = ggml_type_size(op->src[0]->type);
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:

View file

@ -4,18 +4,23 @@
template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.y;
const int64_t sample = blockIdx.z;
const int64_t channel_dst = blockIdx.y;
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int64_t sample_dst = blockIdx.z;
const int64_t sample_x = sample_dst / sample_ratio;
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += sample *stride_sample_y + channel *stride_channel_y;
dst += sample *stride_sample_dst + channel *stride_channel_dst;
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
@ -31,12 +36,19 @@ static __global__ void mul_mat_vec(
float sumf = 0.0f;
if constexpr (std::is_same<T, half>::value) {
if constexpr (std::is_same<T, float>::value) {
const float2 * x2 = (const float2 *) x;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2];
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x;
sumf += tmpx.y*tmpy.y;
}
} else if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
sumf = 0.0f;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
const float2 tmpy = y2[col2];
@ -59,8 +71,6 @@ static __global__ void mul_mat_vec(
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
const int * x2 = (const int *) x;
sumf = 0.0f;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
const float2 tmpy = y2[col2];
@ -92,17 +102,17 @@ static __global__ void mul_mat_vec(
template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0);
GGML_ASSERT(nsamples_y % nsamples_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x;
const int64_t sample_ratio = nsamples_y / nsamples_x;
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
int device;
int warp_size;
@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda(
}
const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_y, nsamples_y);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda(
template<typename T>
static void mul_mat_vec_cuda(
const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
switch (prec) {
case GGML_PREC_DEFAULT: {
if constexpr(std::is_same<T, half>::value) {
if (prec == GGML_PREC_DEFAULT) {
launch_mul_mat_vec_cuda<T, half>
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case GGML_PREC_F32: {
launch_mul_mat_vec_cuda<T, float>
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
launch_mul_mat_vec_cuda<T, float>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne11 == 1);
GGML_ASSERT(ne12 == ne2);
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT(nb00 == ts_src0);
GGML_ASSERT(nb10 == ts_src1);
GGML_ASSERT(nb0 == ts_dst);
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec(
const int64_t stride_row = ne00;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_y = 1;
const int64_t nsamples_dst = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

View file

@ -3,7 +3,7 @@
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
#define MMV_MAX_ROWS 512
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec(
ggml_backend_cuda_context & ctx,

View file

@ -1,6 +1,9 @@
#include "mmvq.cuh"
#include "quantize.cuh"
#include "vecdotq.cuh"
#include <cstdint>
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
@ -73,9 +76,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_y) {
switch (ncols_dst) {
case 1:
case 2:
case 3:
@ -90,7 +93,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete
return 1;
}
} else if (table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
switch (ncols_dst) {
case 1:
case 2:
case 3:
@ -107,9 +110,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
switch (ncols_dst) {
case 1:
return 1;
case 2:
@ -127,19 +130,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta
return 1;
}
template <ggml_type type, int ncols_y>
template <ggml_type type, int ncols_dst>
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_y, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@ -147,13 +152,21 @@ static __global__ void mul_mat_vec_q(
const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
// The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1.
const int channel_dst = blockIdx.y;
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
const block_q8_1 * y = (const block_q8_1 *) vy;
// partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
@ -162,18 +175,19 @@ static __global__ void mul_mat_vec_q(
const int kqs = vdr * (tid % (qi/vdr));
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
tmp[j][i] += vec_dot_q_cuda(
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
@ -185,9 +199,11 @@ static __global__ void mul_mat_vec_q(
return;
}
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
@ -197,88 +213,121 @@ static __global__ void mul_mat_vec_q(
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
}
}
GGML_UNUSED(nrows_x);
}
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
template <ggml_type type>
static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const int channel_ratio = nchannels_dst / nchannels_x;
const int sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
switch (ncols_y) {
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1:
{
constexpr int c_ncols_y = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 2:
{
constexpr int c_ncols_y = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 3:
{
constexpr int c_ncols_y = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 4:
{
constexpr int c_ncols_y = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 5:
{
constexpr int c_ncols_y = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 6:
{
constexpr int c_ncols_y = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 7:
{
constexpr int c_ncols_y = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 8:
{
constexpr int c_ncols_y = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
default:
@ -287,137 +336,213 @@ static void mul_mat_vec_q_cuda(
}
}
static void mul_mat_vec_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
static void mul_mat_vec_q_switch_type(
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
const int ncols_x, const int nrows_x, const int ncols_dst,
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
static void mul_mat_vec_q4_1_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
void ggml_cuda_mul_mat_vec_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
GGML_TENSOR_BINARY_OP_LOCALS;
static void mul_mat_vec_q5_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
cudaStream_t stream = ctx.stream();
mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
static void mul_mat_vec_q5_1_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
GGML_ASSERT( nb00 == ts_src0);
GGML_ASSERT( nb10 == ts_src1);
GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
static void mul_mat_vec_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
float * dst_d = (float *) dst->data;
mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
{
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
}
static void mul_mat_vec_q2_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s11 = ne10_padded / QK8_1;
const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;
mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
const int64_t s12 = ne11*s11;
const int64_t s13 = ne12*s12;
static void mul_mat_vec_q3_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q4_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q5_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq1_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq1_m_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
@ -440,68 +565,12 @@ void ggml_cuda_op_mul_mat_vec_q(
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
const int stride_col_y = src1_padded_row_size / QK8_1;
mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);

View file

@ -2,6 +2,9 @@
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,

View file

@ -1,23 +1,33 @@
#include "quantize.cuh"
#include <cstdint>
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
static __global__ void quantize_q8_1(
const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) {
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (ix0 >= kx0_padded) {
if (i0 >= ne0) {
return;
}
const int64_t ix1 = blockIdx.y;
const int64_t i1 = blockIdx.y;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i_padded = ix1*kx0_padded + ix0;
const int64_t & i00 = i0;
const int64_t & i01 = i1;
const int64_t & i02 = i2;
const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
block_q8_1 * y = (block_q8_1 *) vy;
const int64_t ib = i_padded / QK8_1; // block index
const int64_t iqs = i_padded % QK8_1; // quant index
const int64_t ib = i_cont / QK8_1; // block index
const int64_t iqs = i_cont % QK8_1; // quant index
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
@ -127,43 +137,45 @@ static __global__ void quantize_mmq_q8_1(
}
void quantize_row_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(kx0_padded % QK8_1 == 0);
GGML_ASSERT(ne0 % QK8_1 == 0);
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, kx1*channels, 1);
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
GGML_UNUSED(type_x);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
GGML_UNUSED(type_src0);
}
void quantize_mmq_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, kx1, channels);
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_x)) {
switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
break;
case MMQ_Q8_1_DS_LAYOUT_DS4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
break;
case MMQ_Q8_1_DS_LAYOUT_D2S6:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
break;
default:
GGML_ABORT("fatal error");
break;
}
GGML_UNUSED(s01);
GGML_UNUSED(s02);
GGML_UNUSED(s03);
}

View file

@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
typedef void (*quantize_cuda_t)(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
const ggml_type type_x, cudaStream_t stream);
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
void quantize_row_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
const ggml_type type_x, cudaStream_t stream);
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
void quantize_mmq_q8_1_cuda(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
const ggml_type type_x, cudaStream_t stream);
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);

View file

@ -1,3 +1,5 @@
#pragma once
#include "common.cuh"
#include <cstdint>

View file

@ -485,6 +485,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_OUTPUT = auto()
V_ENC_OUTPUT_NORM = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
@ -501,6 +502,7 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_Q_NORM = auto() # minicpmv
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -737,6 +739,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
@ -753,6 +756,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -771,6 +775,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_ENC_OUTPUT,
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
@ -787,6 +792,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_Q_NORM,
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@ -2129,6 +2135,7 @@ class GGUFValueType(IntEnum):
class VisionProjectorType:
GEMMA3 = "gemma3"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
# Items here are (block size, type size)

View file

@ -914,6 +914,7 @@ class TensorNameMap:
"vision_tower.vision_model.embeddings.patch_embedding",
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
"vision_tower.patch_conv", # pixtral
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
@ -926,52 +927,65 @@ class TensorNameMap:
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
),
MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
),
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
),
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral
),
MODEL_TENSOR.V_POST_NORM: (
@ -1030,6 +1044,10 @@ class TensorNameMap:
MODEL_TENSOR.V_RESMPL_QUERY: (
"resampler.query",
),
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
"v.token_embd.img_break", # for pixtral, this is a generated vector
),
}
# architecture-specific block mappings

View file

@ -2890,7 +2890,8 @@ static void PrepareLlavaEmbds(const int nctx, const std::vector<int> & llava_sep
{
printf("\nLLAVA Clip Embed %i used Tokens: %d",i,llava_images[i].clp_image_tokens);
}
if(llava_images[i].clp_image_tokens>0 && llava_images[i].clp_image_tokens < nctx)
int cliptokensneeded = llava_images[i].clp_image_tokens;
if(cliptokensneeded>0 && cliptokensneeded < nctx)
{
int tokcnt = (i==0?(llava_images[i].clp_image_tokens):(llava_images[i].clp_image_tokens+sepsize));
if(i==0)
@ -2904,7 +2905,7 @@ static void PrepareLlavaEmbds(const int nctx, const std::vector<int> & llava_sep
}
else
{
printf("\nWarning: LLAVA Image excluded - Context size too low or not enough clip tokens!\n");
printf("\nWarning: LLAVA Image excluded - Context size too low or not enough clip tokens! (needed %d)\n",cliptokensneeded);
}
}
}

View file

@ -113,6 +113,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
};
enum llama_rope_type {

View file

@ -2945,6 +2945,7 @@ Current version indicated by LITEVER below.
const default_claude_base = "https://api.anthropic.com";
const default_gemini_base = "https://generativelanguage.googleapis.com/v1beta/models/";
const default_gemini_suffix = ":generateContent?key=";
const default_gemini_stream_suffix = ":streamGenerateContent?alt=sse&key=";
const default_cohere_base = "https://api.cohere.ai/v1/chat";
const a1111_models_endpoint = "/sdapi/v1/sd-models";
@ -5075,6 +5076,43 @@ Current version indicated by LITEVER below.
render_gametext();
});
}
function gemini_api_sync_req(targetep, payload, geminiheaders)
{
fetch(targetep, {
method: 'POST',
headers: geminiheaders,
body: JSON.stringify(payload),
referrerPolicy: 'no-referrer',
})
.then((response) => response.json())
.then((data) => {
console.log("sync finished response: " + JSON.stringify(data));
last_response_obj = JSON.parse(JSON.stringify(data));
if (custom_gemini_key != "" && data.candidates != null && data.candidates.length>0 && data.candidates[0].output && data.candidates[0].output != "") {
synchro_polled_response = data.candidates[0].output;
}else if (custom_gemini_key != "" && data.candidates != null && data.candidates.length>0 && data.candidates[0].content && data.candidates[0].content.parts != null && data.candidates[0].content.parts.length>0) {
synchro_polled_response = data.candidates[0].content.parts[0].text;
//try to handle the stripping of spaces
if(localsettings.opmode==1 && gametext_arr.length>0 && synchro_polled_response!="")
{
synchro_polled_response = cleanup_story_completion(synchro_polled_response);
}
}
else {
//error occurred, maybe captcha failed
console.error("error occurred in Gemini generation");
clear_poll_flags();
render_gametext();
msgbox("Error occurred during text generation: " + format_json_error(data));
}
})
.catch((error) => {
console.error('Error:', error);
clear_poll_flags();
render_gametext();
msgbox("Error while submitting prompt: " + error);
});
}
function oai_api_sync_req(targetep,oai_payload,oaiheaders)
{
fetch(targetep, {
@ -5181,6 +5219,16 @@ Current version indicated by LITEVER below.
if(pending_response_id && pending_response_id != "-1" && pending_response_id != "")
{
for (let event of chunk) {
//for gemini
if (event.data && event.data.candidates && event.data.candidates.length>0) {
if(event.data.candidates[0].content && event.data.candidates[0].content.parts && event.data.candidates[0].content.parts.length>0 && event.data.candidates[0].content.parts[0].text)
{
synchro_pending_stream += event.data.candidates[0].content.parts[0].text;
}
continue;
}
//for oai
if (event.data && event.data.choices && event.data.choices.length>0) {
if(event.data.choices[0].text)
{
@ -15025,7 +15073,6 @@ Current version indicated by LITEVER below.
else if (custom_gemini_key != "")//handle for Gemini
{
let mdlname = document.getElementById("custom_gemini_model").value;
let urlbase = default_gemini_base + mdlname + default_gemini_suffix;
let geminitopk = (submit_payload.params.top_k<1?100:submit_payload.params.top_k);
geminitopk = geminitopk>100?100:geminitopk;
if(mdlname.includes("flash-8b"))
@ -15039,7 +15086,7 @@ Current version indicated by LITEVER below.
submit_payload.params.max_length += 100; //add length
}
urlbase = default_gemini_base + mdlname + default_gemini_suffix;
let geminiparts = [];
if (insertAIVisionImages.length > 0) {
@ -15134,47 +15181,21 @@ Current version indicated by LITEVER below.
};
}
let targetep = urlbase + custom_gemini_key;
last_request_str = JSON.stringify(payload);
last_response_obj = null;
fetch(targetep, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(payload),
referrerPolicy: 'no-referrer',
})
.then((response) => response.json())
.then((data) => {
console.log("sync finished response: " + JSON.stringify(data));
last_response_obj = JSON.parse(JSON.stringify(data));
if (custom_gemini_key != "" && data.candidates != null && data.candidates.length>0 && data.candidates[0].output && data.candidates[0].output != "") {
synchro_polled_response = data.candidates[0].output;
}else if (custom_gemini_key != "" && data.candidates != null && data.candidates.length>0 && data.candidates[0].content && data.candidates[0].content.parts != null && data.candidates[0].content.parts.length>0) {
synchro_polled_response = data.candidates[0].content.parts[0].text;
//try to handle the stripping of spaces
if(localsettings.opmode==1 && gametext_arr.length>0 && synchro_polled_response!="")
let geminiheaders = { 'Content-Type': 'application/json' };
if(is_browser_supports_sse() && document.getElementById("geministreaming").checked)
{
synchro_polled_response = cleanup_story_completion(synchro_polled_response);
let targetep = default_gemini_base + mdlname + default_gemini_stream_suffix + custom_gemini_key;
oai_api_stream_sse(targetep,payload,geminiheaders);
}
else
{
let targetep = default_gemini_base + mdlname + default_gemini_suffix + custom_gemini_key;
gemini_api_sync_req(targetep,payload,geminiheaders);
}
else {
//error occurred, maybe captcha failed
console.error("error occurred in Gemini generation");
clear_poll_flags();
render_gametext();
msgbox("Error occurred during text generation: " + format_json_error(data));
}
})
.catch((error) => {
console.error('Error:', error);
clear_poll_flags();
render_gametext();
msgbox("Error while submitting prompt: " + error);
});
}
else if (custom_cohere_key != "")//handle for Cohere
{
@ -22727,8 +22748,15 @@ Current version indicated by LITEVER below.
</div>
<textarea title="Gemini System Prompt" class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="gemini_system_instruction" placeholder="(Enter System Instruction)"
value=""></textarea><br>
<div style="display:inline-flex">
<div>
<input type="checkbox" title="Use Gemini WebSearch" id="usegeminiweb">
<div class="box-label">Use WebSearch</div><br>
<div class="box-label">Use WebSearch</div>
</div>
<div><input type="checkbox" id="geministreaming" title="Enable SSE Streaming" onchange="">
<div class="box-label">Streaming</div>
</div>
</div>
</div>
<div id="coherecustom" class="menutext hidden">
Uses Cohere's models through their own API.<br><br>

View file

@ -0,0 +1,112 @@
ied 4 ½ months
__ggml_vocab_test__
Führer
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello world
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World
__ggml_vocab_test__
Hello World!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
Hello, world!
__ggml_vocab_test__
this is 🦙.cpp
__ggml_vocab_test__
w048 7tuijk dsdfhu
__ggml_vocab_test__
нещо на Български
__ggml_vocab_test__
កាន់តែពិសេសអាចខលចេញ
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
__ggml_vocab_test__
Hello
Hello
__ggml_vocab_test__
(
__ggml_vocab_test__
=
__ggml_vocab_test__
' era
__ggml_vocab_test__
Hello, y'all! How are you 😁 ?我想在apple工作1314151天
__ggml_vocab_test__
!!!!!!
__ggml_vocab_test__
3
__ggml_vocab_test__
33
__ggml_vocab_test__
333
__ggml_vocab_test__
3333
__ggml_vocab_test__
33333
__ggml_vocab_test__
333333
__ggml_vocab_test__
3333333
__ggml_vocab_test__
33333333
__ggml_vocab_test__
333333333
__ggml_vocab_test__
Cửa Việt
__ggml_vocab_test__
discards
__ggml_vocab_test__
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天 ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
__ggml_vocab_test__

View file

@ -0,0 +1,46 @@
2014 1032 1052 1032 28504 6972
1070 7088 1258
1032
1256
1293
1009
1010
1267
4688
1009 1010
22177 4304
45383 4304
22177 5325
45383 5325
45383 5325 1033
22177 1044 4304 1033
45383 1044 4304 1033
1593 1395 119685 1166 1153 1046 51228
1119 1048 1052 1056 1032 1055 17391 23216 30203 7785 17279
3337 30757 1902 4200 63073 3671
1225 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1225 1158 1129 1225 1158 1155 1225 1158 1133 1225 21359 1225 1158 1137
1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 1319 11234 1873 26303 1455 1934 2246 3754 10835 1041
22177
45383
1032 45383
1256 45383
1293 45383
1293 45383 1010 1293 45383
1319
1010 1376
1039 4033
22177 1044 1404 48054 1033 3075 1584 1636 119685 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749
7290 7290 7290
1051
1051 1051
1051 1051 1051
1051 1051 1051 1051
1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051 1051
1051 1051 1051 1051 1051 1051 1051 1051 1051
1067 59503 28783
3724 4058
1010 1032 1267 1032 4688 1032 17152 1458 29356 1010 1256 1010 1293 1010 1260 1010 1652 1010 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 119685 1166 1153 1240 1159 1166 1153 1032 1051 1032 1051 1051 1032 1051 1051 1051 1032 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1051 1032 1051 1046 1051 1032 1051 1791 1051 1032 1051 2880 1051 71881 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1240 1159 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749 45577 1045 6626 43555 2843 30757 1902 4200 63073 3671 14931 20040 20040 1657 1657 1975 14135 14135 83923 7290 7290 7290 45509 45509 45509 1362 6483 2151 1576 1116 2189 1514 1681 2156 1044 1576 3609 1636 5257 1063 1576 1077 1605 5257 1362 7534 3180 1494 1044 1576 1068 1636 2479 2269 26883 1063 2837 1039 45654 1261 54297 1076

View file

@ -1741,7 +1741,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "llama3" ||
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe"||
tokenizer_pre == "falcon3") {
tokenizer_pre == "falcon3" ||
tokenizer_pre == "pixtral") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
add_bos = true;