Merge commit '65349f26f2' into concedo_experimental

# Conflicts:
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	tests/test-backend-ops.cpp
Concedo 2025-08-19 11:35:32 +08:00
commit 140ae92886
15 changed files with 1587 additions and 118 deletions


@@ -8251,8 +8251,7 @@ class GptOssModel(TextModel):
        self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
@ModelBase.register("Lfm2ForCausalLM", "LFM2ForCausalLM")
class LFM2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.LFM2
@@ -8287,6 +8286,13 @@ class LFM2Model(TextModel):
        self._add_feed_forward_length()
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
        if is_vision_tensor:
            # skip vision tensors
            return []
        name = name.replace("language_model.", "")
        # conv op requires 2d tensor
        if 'conv.conv' in name:
            data_torch = data_torch.squeeze(1)
@@ -8294,6 +8300,41 @@ class LFM2Model(TextModel):
        return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Lfm2VlForConditionalGeneration")
class LFM2VLModel(MmprojModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        # TODO(tarek): for dynamic resolution image_size is not specified, setting here for compatibility
        self.hparams_vision["image_size"] = 256
    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2)
        self.gguf_writer.add_vision_attention_layernorm_eps(self.find_vparam(["layer_norm_eps"]))
        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("downsample_factor", 2))
        self.gguf_writer.add_vision_use_gelu(True)
        # python notation, e.g. for vision_feature_layer == -1, we pick last layer -> vision_feature_layers_to_drop = 0
        vision_feature_layers_to_drop = -(self.global_config.get("vision_feature_layer", -1) + 1)
        self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys) - vision_feature_layers_to_drop)
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        del bid # unused
        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
        if is_vision_tensor:
            # remove "model." prefix
            name = name.replace("model.vision_tower.", "vision_tower.")
            name = name.replace("model.multi_modal_projector.", "multi_modal_projector.")
            if "patch_embedding.weight" in name:
                data_torch = data_torch.view(data_torch.shape[0], 16, 16, 3).permute(0, 3, 1, 2)
            return [(self.map_tensor_name(name), data_torch)]
        return [] # skip other tensors
@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
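The two less obvious steps in LFM2VLModel above are the vision_feature_layer mapping and the patch-embedding reshape. A small standalone sketch of what they compute (not part of the commit; the helper names and tensor sizes are invented for illustration, and the flattened weight layout is inferred from the view/permute above):

```python
import torch

def vision_blocks_to_keep(n_blocks: int, vision_feature_layer: int = -1) -> int:
    # Python negative indexing: -1 keeps every block, -2 drops the last block, etc.
    layers_to_drop = -(vision_feature_layer + 1)
    return n_blocks - layers_to_drop

assert vision_blocks_to_keep(27, -1) == 27
assert vision_blocks_to_keep(27, -2) == 26

# patch_embedding.weight is assumed to arrive flattened over 16x16 patches with 3 channels;
# the view/permute restores the conv2d layout (out_channels, in_channels, kH, kW).
w = torch.randn(768, 16 * 16 * 3)   # 768 is an arbitrary example embedding size
w_conv = w.view(w.shape[0], 16, 16, 3).permute(0, 3, 1, 2)
assert w_conv.shape == (768, 3, 16, 16)
```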


@@ -0,0 +1,343 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return pow(base, exph);
}
__kernel void flash_attn_f16(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
}
}
if (my_query_row < n_q) {
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
}
}
}
__kernel void flash_attn_f16_q1(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
}
}
} else if (tid == 0) {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
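For reference, the per-tile update in flash_attn_f16 is the standard streaming-softmax ("flash attention") recurrence over m_i, l_i and o_acc. A minimal NumPy sketch of that recurrence for one query row, with the mask, ALiBi and logit-softcap terms omitted (illustration only, not part of the commit):

```python
import numpy as np

def flash_attn_one_row(q, k, v, scale, block_n=32):
    # q: (dk,), k: (n_kv, dk), v: (n_kv, dv)
    m_i, l_i = -np.inf, 0.0
    o_acc = np.zeros(v.shape[1], dtype=np.float64)
    for start in range(0, k.shape[0], block_n):
        k_blk, v_blk = k[start:start + block_n], v[start:start + block_n]
        scores = (k_blk @ q) * scale        # one score per KV row in the tile
        m_new = max(m_i, scores.max())      # running maximum
        p = np.exp(scores - m_new)          # tile weights against the new max
        scale_prev = np.exp(m_i - m_new)    # rescale previously accumulated sums
        o_acc = o_acc * scale_prev + p @ v_blk
        l_i = l_i * scale_prev + p.sum()
        m_i = m_new
    return o_acc / l_i                      # equals softmax(scale * k @ q) @ v
```

In the kernel, each work-item runs this recurrence for its own query row while the K and V tiles are staged cooperatively through local memory.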


@@ -0,0 +1,343 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define DATA_TYPE float
#define DATA_TYPE4 float4
#define CONVERT_ACC4(x) (x)
#define CONVERT_DATA4(x) (x)
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return pow(base, exph);
}
__kernel void flash_attn_f32(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
}
}
if (my_query_row < n_q) {
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
}
}
}
__kernel void flash_attn_f32_q1(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
}
}
} else if (tid == 0) {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
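The *_q1 variants take a different route for the single-query (decode) case: two full passes over the KV rows, one to find the global score maximum and one to accumulate the exp-weighted V rows, with work-group tree reductions in between. A rough NumPy equivalent of what one work-group computes (mask, ALiBi and softcap again omitted; illustration only):

```python
import numpy as np

def flash_attn_single_query(q, k, v, scale):
    # pass 1: score every KV row, reduce to the global maximum (the local_m reduction)
    scores = (k @ q) * scale
    m_final = scores.max()
    # pass 2: weights relative to that maximum, then the l and o reductions
    p = np.exp(scores - m_final)
    return (p @ v) / p.sum()
```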


@@ -0,0 +1,346 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
#define KV_DATA_TYPE4 half4
#define O_DATA_TYPE4 float4
#define MASK_DATA_TYPE half
#define CONVERT_Q_ACC4(x) (x)
#define CONVERT_KV_ACC4(x) convert_float4(x)
#define CONVERT_O_DATA4(x) (x)
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return pow(base, exph);
}
__kernel void flash_attn_f32_f16(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
}
}
if (my_query_row < n_q) {
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
}
__kernel void flash_attn_f32_f16_q1(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * o_void, ulong o_offset,
const float scale,
const int n_q,
const int n_kv,
const int is_causal,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void* mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
const global char* v_base = (const global char*)v_void + v_offset;
global char* o_base = (global char*)o_void + o_offset;
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
score += slope * (ACC_TYPE)mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0];
if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final;
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);
}
}
} else if (tid == 0) {
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
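All six kernels share the same grouped-query-attention head mapping: gqa_ratio query heads read from one KV head, selected by integer division. A tiny sketch of that mapping (the head counts are illustrative, not taken from the commit):

```python
def kv_head_for(head_idx: int, n_head: int, n_head_kv: int) -> int:
    gqa_ratio = n_head // n_head_kv   # query heads per shared KV head
    return head_idx // gqa_ratio

# e.g. n_head = 32, n_head_kv = 8: query heads 0-3 read KV head 0, heads 4-7 read KV head 1, ...
assert [kv_head_for(h, 32, 8) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```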


@@ -119,6 +119,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
struct ggml_backend_vk_context;
#define MAX_PARAMETER_COUNT 8
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
struct vk_pipeline_struct {
    std::string name;
@@ -384,6 +386,7 @@ struct vk_device_struct {
    bool float_controls_rte_fp16;
    bool subgroup_add;
    bool subgroup_shuffle;
    bool multi_add;
    bool integer_dot_product;
@@ -465,6 +468,9 @@ struct vk_device_struct {
    vk_pipeline pipeline_div[2][2][2];
    vk_pipeline pipeline_div_norepeat[2][2][2];
    // indexed by num_additional_fused_ops == num_adds - 1
    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
    vk_pipeline pipeline_add_id_f32;
    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -817,6 +823,14 @@ struct vk_op_binary_push_constants {
    float param1; float param2; int32_t param3;
};
struct vk_op_multi_add_push_constants {
    // shape for dst
    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
    // strides for srcs+dst
    uint32_t nb[8][4];
};
struct vk_op_add_id_push_constants {
    uint32_t ne0;
    uint32_t ne1;
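The MAX_FUSED_ADDS bound and the fixed-size nb[8][4] stride array above follow from the same budget: with MAX_PARAMETER_COUNT descriptor slots, a chain of n fused adds binds n + 1 source tensors plus one destination (that reading is an assumption here, not stated in the diff). A quick Python check of the arithmetic:

```python
MAX_PARAMETER_COUNT = 8                     # from the define above
MAX_FUSED_ADDS = MAX_PARAMETER_COUNT - 2    # = 6

def fits_in_parameter_budget(num_adds: int) -> bool:
    # num_adds chained additions read num_adds + 1 sources and write one destination
    return (num_adds + 1) + 1 <= MAX_PARAMETER_COUNT

assert fits_in_parameter_budget(MAX_FUSED_ADDS)
assert not fits_in_parameter_budget(MAX_FUSED_ADDS + 1)
```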
@@ -2403,26 +2417,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
}
#endif
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2518,51 +2532,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
if (device->coopmat_acc_f16_support) {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MM
} else
@@ -2647,27 +2637,27 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f16acc, matmul_id_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f16acc, matmul_id_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
@ -3034,6 +3024,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY
if (device->multi_add) {
for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
}
}
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@ -3587,6 +3583,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
vk12_features.runtimeDescriptorArray &&
device->vendor_id != VK_VENDOR_ID_INTEL &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@ -4500,7 +4502,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
return nullptr;
}
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
// XXX TODO 'prec' is not actually allowed in mul_mat_id.
bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr;
bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr;
if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc;
} else {
GGML_ASSERT(support_fp32acc);
return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
}
}
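A minimal sketch of the accumulator-precision fallback introduced above, using illustrative types rather than the backend's pipeline structures: the f16-accumulator variant is preferred whenever the device supports fp16 and that variant was compiled, and the f32 variant is only a fallback.

// Sketch only; 'acc_pipelines' and 'pick_acc_pipeline' are hypothetical names.
struct acc_pipelines { const void * f16acc; const void * f32acc; };

static const void * pick_acc_pipeline(bool device_fp16, const acc_pipelines & p) {
    const bool prefer_fp16acc  = device_fp16;             // 'prec' is not consulted for mul_mat_id
    const bool support_fp16acc = p.f16acc != nullptr;
    const bool support_fp32acc = p.f32acc != nullptr;
    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
        return p.f16acc;
    }
    return p.f32acc;                                       // the real code asserts this variant exists
}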
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@ -6931,6 +6943,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
switch (op) {
case GGML_OP_ADD:
{
if (ctx->num_additional_fused_ops > 0) {
return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
}
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
}
@ -7787,6 +7802,107 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
}, dryrun);
}
static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
// Make a list of all the tensors used by the op.
// Last element of the list is the dest tensor.
const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
uint32_t num_tensors = num_srcs + 1;
GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
tensors[0] = first_node->src[0];
tensors[1] = first_node->src[1];
for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
// check whether the previous result is src[0] or src[1]
if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
} else {
tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
}
}
tensors[num_srcs] = dst;
vk_op_multi_add_push_constants pc;
pc.ne20 = (uint32_t)dst->ne[0];
pc.ne21 = (uint32_t)dst->ne[1];
pc.ne22 = (uint32_t)dst->ne[2];
pc.ne23 = (uint32_t)dst->ne[3];
for (uint32_t i = 0; i < num_tensors; ++i) {
const ggml_tensor *t = tensors[i];
pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
}
vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
if (pipeline == nullptr) {
std::cerr << "ggml_vulkan: Error: Missing multi_add";
GGML_ABORT("fatal error");
}
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
vk_buffer buf[MAX_PARAMETER_COUNT];
size_t offset[MAX_PARAMETER_COUNT];
bool uma[MAX_PARAMETER_COUNT];
for (uint32_t i = 0; i < num_tensors; ++i) {
buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
buf[i] = nullptr;
offset[i] = 0;
uma[i] = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
uma[i] = buf[i] != nullptr;
}
if (!uma[i]) {
buf[i] = buf_ctx[i]->dev_buffer;
offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
}
GGML_ASSERT(buf[i] != nullptr);
}
// If any remaining descriptors are unused, just point them at src[0]
for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
buf[i] = buf[0];
offset[i] = 0;
}
std::array<uint32_t, 3> elements;
uint32_t ne = ggml_nelements(dst);
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
}, pc, elements);
}
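How ggml_vk_multi_add() sizes its dispatch, as a standalone sketch (the helper name is hypothetical, not the backend API): one 256-thread workgroup runs two iterations per invocation and therefore covers 512 elements, so the x extent is capped at 512, rows advance in steps of 512, and z slabs advance in steps of 512*512 = 262144, matching get_idx() in multi_add.comp.

#include <array>
#include <cstdint>

// Sketch of the element-to-grid mapping used above.
static std::array<uint32_t, 3> multi_add_grid(uint32_t ne) {
    auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };
    if (ne > 262144) { return { 512u, 512u, ceil_div(ne, 262144u) }; }
    if (ne > 512)    { return { 512u, ceil_div(ne, 512u), 1u }; }
    return { ne, 1u, 1u };
}
// e.g. ne = 1000000 gives {512, 512, 4}; the shader recomputes the flat index
// as z*262144 + y*512 + x and skips any index >= ne.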
static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@ -9747,8 +9863,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ADD:
ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
if (ctx->num_additional_fused_ops) {
ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
}
break;
case GGML_OP_SUB:
ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@ -10630,6 +10749,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
if (first_node->op != GGML_OP_ADD) {
return 0;
}
if (!ctx->device->multi_add) {
return 0;
}
int32_t num_adds = 1;
while (node_idx + num_adds < cgraph->n_nodes &&
cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
num_adds < MAX_FUSED_ADDS) {
num_adds++;
}
// The shader currently requires same shapes (but different strides are allowed),
// everything f32, and no misalignment
for (int32_t i = 0; i < num_adds; ++i) {
const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
!ggml_are_same_shape(first_node, next_node->src[1]) ||
next_node->type != GGML_TYPE_F32 ||
next_node->src[0]->type != GGML_TYPE_F32 ||
next_node->src[1]->type != GGML_TYPE_F32 ||
get_misalign_bytes(ctx, next_node) ||
get_misalign_bytes(ctx, next_node->src[0]) ||
get_misalign_bytes(ctx, next_node->src[1])) {
num_adds = i;
}
}
// Verify we can fuse these
ggml_op adds[MAX_FUSED_ADDS];
for (int32_t i = 0; i < num_adds; ++i) {
adds[i] = GGML_OP_ADD;
}
// decrease num_adds if they can't all be fused
while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
num_adds--;
}
// a single add is not "fused", so just return zero
if (num_adds == 1) {
return 0;
}
return num_adds;
}
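For illustration, a chain that this pass can fuse, written against the plain ggml API with hypothetical tensor names (and assuming the intermediate results feed only the next add and are not graph outputs):

#include "ggml.h"

// Three consecutive f32 ADD nodes with identical shapes form one fusable chain.
static struct ggml_tensor * build_add_chain(struct ggml_context * ctx,
        struct ggml_tensor * a, struct ggml_tensor * b,
        struct ggml_tensor * c, struct ggml_tensor * d) {
    struct ggml_tensor * t = ggml_add(ctx, a, b);
    t = ggml_add(ctx, t, c);
    t = ggml_add(ctx, t, d);
    return t;
}

For such a chain ggml_vk_fuse_multi_add() returns 3, the caller stores num_additional_fused_ops = 2, and ggml_vk_multi_add() then binds num_srcs = 4 inputs (a, b, c, d) plus the final destination in a single dispatch.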
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@ -10643,8 +10814,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@ -10719,8 +10895,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
}
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
@ -11753,6 +11934,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else {
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
}
ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
if (src1 == nullptr) {
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
@ -11837,6 +12020,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[0]->flags = src0->flags;
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;

View file

@ -210,7 +210,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@ -233,7 +233,7 @@ void main() {
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
@ -288,7 +288,7 @@ void main() {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
@ -357,7 +357,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
}
}

View file

@ -2,6 +2,7 @@
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "utils.comp"
layout (push_constant) uniform parameter
{
@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}
uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00;
i01 = (idx - i03_offset - i02_offset) / p.ne00;
i00 = idx - i03_offset - i02_offset - i01*p.ne00;
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {

View file

@ -801,7 +801,7 @@ void main() {
}
#else
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) {
if (row_i < _ne1 && block + loadr_b < end_k) {
const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
@ -875,7 +875,9 @@ void main() {
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
@ -925,7 +927,9 @@ void main() {
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);

View file

@ -0,0 +1,68 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "types.comp"
#include "utils.comp"
layout (push_constant) uniform parameter2
{
// shape for dst
uint ne20; uint ne21; uint ne22; uint ne23;
// strides for srcs+dst
uint nb[8][4];
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
layout(constant_id = 0) const uint num_srcs = 2;
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
}
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
uint nb20 = p.nb[num_srcs][0];
uint nb21 = p.nb[num_srcs][1];
uint nb22 = p.nb[num_srcs][2];
uint nb23 = p.nb[num_srcs][3];
return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
}
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);
FLOAT_TYPE sum = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
}
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
}

View file

@ -0,0 +1,25 @@
#ifndef UTILS_COMP
#define UTILS_COMP
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}
uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) {
i03 = fastdiv(idx, (ne02*ne01*ne00));
const uint i03_offset = i03 * ne02*ne01*ne00;
i02 = fastdiv((idx - i03_offset), (ne01*ne00));
const uint i02_offset = i02*ne01*ne00;
i01 = (idx - i03_offset - i02_offset) / ne00;
i00 = idx - i03_offset - i02_offset - i01*ne00;
}
#endif // UTILS_COMP
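A host-side mirror of get_indices(), offered only as a sanity-check sketch (ne03 is accepted by the shader for symmetry but unused): it decomposes a flattened element index into the four tensor coordinates and can be round-tripped against the row-major recomposition.

#include <cassert>
#include <cstdint>

// Plain-division mirror of the shader helper; fastdiv(a, b) equals a / b for unsigned ints.
static void get_indices_host(uint32_t idx, uint32_t ne00, uint32_t ne01, uint32_t ne02,
                             uint32_t & i00, uint32_t & i01, uint32_t & i02, uint32_t & i03) {
    i03 = idx / (ne02*ne01*ne00);
    const uint32_t r3 = idx - i03*ne02*ne01*ne00;
    i02 = r3 / (ne01*ne00);
    const uint32_t r2 = r3 - i02*ne01*ne00;
    i01 = r2 / ne00;
    i00 = r2 - i01*ne00;
}

int main() {
    const uint32_t ne00 = 7, ne01 = 3, ne02 = 5, ne03 = 2;   // arbitrary shape
    for (uint32_t idx = 0; idx < ne00*ne01*ne02*ne03; ++idx) {
        uint32_t i00, i01, i02, i03;
        get_indices_host(idx, ne00, ne01, ne02, i00, i01, i02, i03);
        assert(idx == ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00);
    }
    return 0;
}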

View file

@ -691,6 +691,8 @@ void process_shaders() {
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
for (auto &c : compiles) {
c.wait();
}

View file

@ -2832,6 +2832,7 @@ class VisionProjectorType:
QWEN2A = "qwen2a" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
# Items here are (block size, type size)

View file

@ -1272,6 +1272,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"pre_mm_projector_norm",
),

View file

@ -82,6 +82,7 @@
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_NORM_B "mm.input_norm.bias"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
@ -133,6 +134,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_UNKNOWN,
};
@ -153,6 +155,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View file

@ -283,6 +283,7 @@ struct clip_model {
// LLaVA projection
ggml_tensor * mm_input_norm_w = nullptr;
ggml_tensor * mm_input_norm_b = nullptr;
ggml_tensor * mm_0_w = nullptr;
ggml_tensor * mm_0_b = nullptr;
ggml_tensor * mm_2_w = nullptr;
@ -513,11 +514,17 @@ struct clip_graph {
ggml_cgraph * build_siglip() {
ggml_tensor * inp = build_inp();
ggml_tensor * learned_pos_embd = model.position_embeddings;
if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
learned_pos_embd = resize_position_embeddings();
}
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
model.position_embeddings,
learned_pos_embd,
nullptr);
if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
@ -567,6 +574,45 @@ struct clip_graph {
bsz);
cur = ggml_mul_mat(ctx0, model.projection, cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block
const int scale_factor = model.hparams.proj_scale_factor;
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
if (pad_width || pad_height) {
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
width += pad_width;
height += pad_height;
}
// unshuffle h
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
// unshuffle w
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3));
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
// projection
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
} else {
GGML_ABORT("SigLIP: Unsupported projector type");
}
@ -1585,6 +1631,27 @@ private:
}
}
// siglip2 naflex
ggml_tensor * resize_position_embeddings() {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
if (!pos_embd || height * width == pos_embd->ne[1]) {
return pos_embd;
}
const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1); // -> (width, height, n_embd)
pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
pos_embd = ggml_cont(ctx0, pos_embd);
return pos_embd;
}
// build vision transformer (ViT) cgraph
// this function should cover most of the models
// if your model has specific features, you should probably duplicate this function
@ -1991,6 +2058,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
switch (ctx->proj_type()) {
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_LFM2:
{
res = graph.build_siglip();
} break;
@ -2276,6 +2344,7 @@ struct clip_model_loader {
}
} break;
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_INTERNVL:
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
@ -2584,6 +2653,15 @@ struct clip_model_loader {
{
model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_LFM2:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
@ -3604,6 +3682,43 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_y = inst.grid_size.height;
return true;
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
GGML_ASSERT(params.proj_scale_factor);
// smart resize
const int width = img->nx;
const int height = img->ny;
const int total_factor = params.patch_size * params.proj_scale_factor;
constexpr int min_image_tokens = 64;
constexpr int max_image_tokens = 256;
const float min_pixels = min_image_tokens * total_factor * total_factor;
const float max_pixels = max_image_tokens * total_factor * total_factor;
auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
auto ceil_by_factor = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
int h_bar = std::max(total_factor, round_by_factor(height));
int w_bar = std::max(total_factor, round_by_factor(width));
if (h_bar * w_bar > max_pixels) {
const auto beta = std::sqrt((height * width) / max_pixels);
h_bar = std::max(total_factor, floor_by_factor(height / beta));
w_bar = std::max(total_factor, floor_by_factor(width / beta));
} else if (h_bar * w_bar < min_pixels) {
const auto beta = std::sqrt(min_pixels / (height * width));
h_bar = ceil_by_factor(height * beta);
w_bar = ceil_by_factor(width * beta);
}
const std::array<uint8_t, 3> pad_color = {122, 116, 104};
clip_image_u8 resized_img;
image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
return true;
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
@ -3806,6 +3921,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
n_patches_sq /= 2;
}
} break;
case PROJECTOR_TYPE_LFM2:
{
n_patches_sq = (img->nx / (params.patch_size * params.proj_scale_factor)) * (img->ny / (params.patch_size * params.proj_scale_factor));
} break;
default:
GGML_ABORT("unsupported projector type");
}
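Putting the LFM2 smart-resize preprocessing and the token-count formula above together, a standalone sketch with assumed parameters (patch_size = 16 and proj_scale_factor = 2 are illustrative, as is the 1000x750 input):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const int patch_size = 16, proj_scale_factor = 2;
    const int total_factor = patch_size * proj_scale_factor;           // 32
    const float min_pixels = 64.0f  * total_factor * total_factor;     // >= 64 image tokens
    const float max_pixels = 256.0f * total_factor * total_factor;     // <= 256 image tokens

    auto round_by_factor = [&](float x) { return (int)std::nearbyintf(x / total_factor) * total_factor; };
    auto floor_by_factor = [&](float x) { return (int)std::floor(x / total_factor) * total_factor; };
    auto ceil_by_factor  = [&](float x) { return (int)std::ceil(x / total_factor) * total_factor; };

    const int width = 1000, height = 750;                              // hypothetical input image
    int w_bar = std::max(total_factor, round_by_factor(width));        // 992
    int h_bar = std::max(total_factor, round_by_factor(height));       // 736
    if ((float)h_bar * w_bar > max_pixels) {                           // 730112 > 262144 -> shrink
        const float beta = std::sqrt((float)(height * width) / max_pixels);
        h_bar = std::max(total_factor, floor_by_factor(height / beta));
        w_bar = std::max(total_factor, floor_by_factor(width  / beta));
    } else if ((float)h_bar * w_bar < min_pixels) {
        const float beta = std::sqrt(min_pixels / (float)(height * width));
        h_bar = ceil_by_factor(height * beta);
        w_bar = ceil_by_factor(width  * beta);
    }
    // same expression as clip_n_output_tokens() uses for PROJECTOR_TYPE_LFM2
    const int n_tokens = (w_bar / total_factor) * (h_bar / total_factor);
    printf("%dx%d -> %dx%d (%d tokens)\n", width, height, w_bar, h_bar, n_tokens);
    return 0;
}

Under these assumptions the image is rescaled to 576x416 and yields 18 * 13 = 234 image tokens, within the 64..256 budget.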
@ -4210,6 +4329,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_VOXTRAL:
{
// do nothing
@ -4491,6 +4611,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_model_proj->ne[1];
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");
}