Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	examples/eval-callback/eval-callback.cpp
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/gelu.cl
#	tests/test-backend-ops.cpp
commit a17c79b1a9
Concedo, 2025-07-07 17:46:58 +08:00
10 changed files with 620 additions and 553 deletions


@@ -517,6 +517,8 @@ struct vk_device_struct {
 ggml_backend_buffer_type buffer_type;
+bool disable_fusion;
 #ifdef GGML_VULKAN_MEMORY_DEBUG
 std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -652,6 +654,7 @@ struct vk_flash_attn_push_constants {
 uint32_t nev3;
 uint32_t nem1;
 uint32_t nem2;
+uint32_t nem3;
 uint32_t nb01;
 uint32_t nb02;
@@ -667,8 +670,7 @@ struct vk_flash_attn_push_constants {
 float max_bias;
 float logit_softcap;
-uint32_t mask;
-uint32_t n_head_log2;
+uint32_t mask_n_head_log2;
 float m0;
 float m1;
@@ -1107,8 +1109,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3531,6 +3533,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->idx = idx;
+device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 return device;
 }
@@ -6135,6 +6139,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 const uint32_t nem1 = mask ? mask->ne[1] : 0;
 const uint32_t nem2 = mask ? mask->ne[2] : 0;
+const uint32_t nem3 = mask ? mask->ne[3] : 0;
 const uint32_t HSK = nek0;
 const uint32_t HSV = nev0;
@@ -6202,7 +6207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 }
 if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
-qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
 // grouped query attention - make the N dimension equal to gqa_ratio, reduce
 // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
 // and change addressing calculations to index Q's dimension 2.
@@ -6372,17 +6377,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 }
 }
+uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
 const vk_flash_attn_push_constants pc = { N, KV,
 (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
 (uint32_t)neq2, (uint32_t)neq3,
 (uint32_t)nek2, (uint32_t)nek3,
 (uint32_t)nev2, (uint32_t)nev3,
-nem1, nem2,
+nem1, nem2, nem3,
 q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
 k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
 v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
 scale, max_bias, logit_softcap,
-mask != nullptr, n_head_log2, m0, m1,
+mask_n_head_log2, m0, m1,
 gqa_ratio, split_kv, split_k };
 ggml_vk_sync_buffers(subctx);
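Note on the packed field above: the mask-present flag and n_head_log2 now share one 32-bit push constant, with bit 16 carrying the flag and the low 16 bits carrying n_head_log2 (see MASK_ENABLE_BIT and N_LOG2_MASK in the shader changes below). A minimal stand-alone C++ sketch of the packing and unpacking, with illustrative values:

    #include <cassert>
    #include <cstdint>

    // Pack as in the diff: bit 16 = "mask present", low 16 bits = n_head_log2.
    static uint32_t pack_mask_n_head_log2(bool has_mask, uint32_t n_head_log2) {
        return (uint32_t(has_mask) << 16) | (n_head_log2 & 0xFFFF);
    }

    int main() {
        const uint32_t MASK_ENABLE_BIT = 1u << 16;
        const uint32_t N_LOG2_MASK     = 0xFFFF;
        const uint32_t packed = pack_mask_n_head_log2(true, 8); // 8 is an arbitrary example
        assert((packed & MASK_ENABLE_BIT) != 0); // mask enabled
        assert((packed & N_LOG2_MASK) == 8);     // n_head_log2 recovered
        return 0;
    }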
@@ -7675,8 +7682,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
 const uint32_t src0_type_size = ggml_type_size(src0->type);
 const uint32_t src1_type_size = ggml_type_size(src1->type);
 const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8906,7 +8912,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
 }
 }
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9167,9 +9173,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 // fused rms_norm + mul
 ggml_tensor *mul = cgraph->nodes[node_idx + 1];
 ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
 } else {
-ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
 }
 break;
 case GGML_OP_RMS_NORM_BACK:
@@ -9329,7 +9335,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 ctx->compute_ctx.reset();
-bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
 if (!ok) {
 if (node->op == GGML_OP_UNARY) {
 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9344,7 +9350,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 return true;
 }
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+GGML_UNUSED(cgraph);
 ggml_backend_buffer * buf = nullptr;
 switch (tensor->op) {
@@ -9454,7 +9461,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
 // Only run if ctx hasn't been submitted yet
 if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-ggml_vk_check_results_0(tensor);
+ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
 use_fence = true;
 #endif
@@ -9474,7 +9481,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
 ggml_vk_wait_for_fence(ctx);
 }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-ggml_vk_check_results_1(tensor);
+ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
 }
@@ -9921,6 +9928,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
 return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+return false;
+}
+if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+// additional constraints specific to this fusion
+const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+// rms_norm only supports f32
+if (mul->src[0]->type != GGML_TYPE_F32 ||
+mul->src[1]->type != GGML_TYPE_F32 ||
+mul->type != GGML_TYPE_F32) {
+return false;
+}
+// if rms_norm is the B operand, then we don't handle broadcast
+if (rms_norm == mul->src[1] &&
+mul->src[0]->ne[1] != rms_norm->ne[1]) {
+return false;
+}
+// rms_norm shader assumes contiguous rows
+if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+return false;
+}
+}
+return true;
+}
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
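Note: the new ggml_vk_can_fuse() gate above only fuses an RMS_NORM that is immediately followed by a MUL consuming its result, and the whole path can be switched off with the GGML_VK_DISABLE_FUSION environment variable. A rough stand-alone C++ sketch of the adjacency pattern being checked (hypothetical stub types; the real code also defers to ggml_can_fuse() and the type/contiguity constraints shown above):

    #include <cstddef>
    #include <vector>

    enum Op { OP_RMS_NORM, OP_MUL, OP_OTHER };

    struct Node {
        Op op = OP_OTHER;
        const Node * src0 = nullptr;
        const Node * src1 = nullptr;
    };

    // true when nodes[i] is an RMS_NORM whose output feeds the MUL at nodes[i + 1]
    static bool is_rms_norm_mul_pair(const std::vector<Node> & nodes, size_t i) {
        if (i + 1 >= nodes.size()) {
            return false;
        }
        const Node & rms = nodes[i];
        const Node & mul = nodes[i + 1];
        return rms.op == OP_RMS_NORM && mul.op == OP_MUL &&
               (mul.src0 == &rms || mul.src1 == &rms);
    }

    int main() {
        std::vector<Node> nodes(2);
        nodes[0].op = OP_RMS_NORM;
        nodes[1].op = OP_MUL;
        nodes[1].src0 = &nodes[0]; // MUL consumes the RMS_NORM output
        return is_rms_norm_mul_pair(nodes, 0) ? 0 : 1;
    }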
@@ -9934,7 +9972,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 uint64_t total_mat_mul_bytes = 0;
 for (int i = 0; i < cgraph->n_nodes; i++) {
-if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
 ctx->num_additional_fused_ops = 1;
 }
 ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -10004,7 +10042,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
 }
-if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
 ctx->num_additional_fused_ops = 1;
 }
@@ -10327,12 +10365,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
 return false;
 }
-// TODO: support broadcast
-// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
-// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
-return false;
-}
 // It's straightforward to support different K/V dequant, but would
 // significantly increase the number of pipelines
 if (op->src[1]->type != op->src[2]->type) {
@@ -10787,11 +10819,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+ggml_tensor * tensor = cgraph->nodes[tensor_idx];
 if (tensor->op == GGML_OP_TRANSPOSE) {
 return;
 }
+bool fused_rms_norm_mul = false;
+int rms_norm_idx = -1;
+if (ctx->num_additional_fused_ops == 1 &&
+tensor->op == GGML_OP_RMS_NORM &&
+cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+fused_rms_norm_mul = true;
+tensor = cgraph->nodes[tensor_idx + 1];
+}
 check_counter++;
 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
 return;
@@ -10819,6 +10861,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 for (int i = 0; i < 6; i++) {
 ggml_tensor * srci = tensor->src[i];
+if (fused_rms_norm_mul) {
+rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+switch (i) {
+case 0: srci = rms_norm->src[0]; break;
+case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+default: continue;
+}
+}
 if (srci == nullptr) {
 continue;
 }
@@ -10876,7 +10927,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 } else if (tensor->op == GGML_OP_SUB) {
 tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
 } else if (tensor->op == GGML_OP_MUL) {
-tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+if (fused_rms_norm_mul) {
+tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+} else {
+tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+}
 } else if (tensor->op == GGML_OP_DIV) {
 tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
 } else if (tensor->op == GGML_OP_CONCAT) {
@@ -11067,10 +11123,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 GGML_ABORT("fatal error");
 }
-ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-ggml_build_forward_expand(cgraph, tensor_clone);
-ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+ggml_build_forward_expand(cgraph_cpu, tensor_clone);
+ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
 ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11093,10 +11149,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+ggml_tensor * tensor = cgraph->nodes[tensor_idx];
 if (tensor->op == GGML_OP_TRANSPOSE) {
 return;
 }
+bool fused_rms_norm_mul = false;
+if (ctx->num_additional_fused_ops == 1 &&
+tensor->op == GGML_OP_RMS_NORM &&
+cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+fused_rms_norm_mul = true;
+tensor = cgraph->nodes[tensor_idx + 1];
+}
 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
 return;
 }


@@ -101,8 +101,8 @@ void main() {
 uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
 uint32_t m_offset = 0;
-if (p.nem2 != 1) {
-m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+if (p.nem2 != 1 || p.nem3 != 1) {
+m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
 }
 [[dont_unroll]]
@@ -149,7 +149,7 @@ void main() {
 }
 }
-if (p.mask != 0) {
+if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
 uint32_t c = (idx + tid) % Bc;
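Note: the mask offset above now broadcasts the mask over both the nem2 and nem3 dimensions via modulo indexing instead of assuming nem3 == 1. A small host-side C++ sketch of the same index math (stand-alone, with illustrative sizes):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the shader's m_offset computation after the change:
    // dims 2 and 3 of Q wrap around the mask's nem2/nem3 extents.
    static uint32_t mask_offset(uint32_t iq2, uint32_t iq3,
                                uint32_t nem1, uint32_t nem2, uint32_t nem3,
                                uint32_t KV) {
        if (nem2 != 1 || nem3 != 1) {
            return ((iq3 % nem3) * nem2 + (iq2 % nem2)) * nem1 * KV;
        }
        return 0;
    }

    int main() {
        // Example: 4 heads sharing a mask with nem2 == 2 -> heads 0/2 and 1/3 read the same rows.
        printf("%u %u %u %u\n",
               mask_offset(0, 0, 16, 2, 1, 128), mask_offset(2, 0, 16, 2, 1, 128),
               mask_offset(1, 0, 16, 2, 1, 128), mask_offset(3, 0, 16, 2, 1, 128));
        return 0;
    }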


@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
 uint32_t nev3;
 uint32_t nem1;
 uint32_t nem2;
+uint32_t nem3;
 uint32_t nb01;
 uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
 float max_bias;
 float logit_softcap;
-uint32_t mask;
-uint32_t n_head_log2;
+uint32_t mask_n_head_log2;
 float m0;
 float m1;
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
 uint32_t k_num;
 } p;
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
 layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 #if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 {
 const uint32_t h = iq2 + (r % p.gqa_ratio);
-const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
 return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
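Note: perElemOpComputeSlope() still computes the usual ALiBi slope per head; the only change is that n_head_log2 is first unpacked from the combined push constant. A CPU-side C++ sketch of the same computation (m0 and m1 are derived from max_bias elsewhere and are just parameters here):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Same formula as the shader: slope = base^exph, where the head index h
    // selects m0 (h < n_head_log2) or m1 (h >= n_head_log2).
    static float alibi_slope(uint32_t h, uint32_t mask_n_head_log2, float m0, float m1) {
        const uint32_t N_LOG2_MASK = 0xFFFF;
        const uint32_t n_head_log2 = mask_n_head_log2 & N_LOG2_MASK;
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? int(h) + 1 : 2 * int(h - n_head_log2) + 1;
        return powf(base, float(exph));
    }

    int main() {
        // head 3 with n_head_log2 = 8 and the mask bit set; m0/m1 chosen arbitrarily
        printf("%f\n", alibi_slope(3, (1u << 16) | 8, 0.5f, 0.7071f));
        return 0;
    }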


@@ -126,8 +126,8 @@ void main() {
 uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
 uint32_t m_offset = 0;
-if (p.nem2 != 1) {
-m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+if (p.nem2 != 1 || p.nem3 != 1) {
+m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
 }
 [[dont_unroll]]
@@ -182,7 +182,7 @@ void main() {
 barrier();
 }
-if (p.mask != 0) {
+if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
 uint32_t c = (idx + tid) % Bc;
 uint32_t r = (idx + tid) / Bc;


@@ -131,8 +131,8 @@ void main() {
 }
 uint32_t m_offset = 0;
-if (p.nem2 != 1) {
-m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+if (p.nem2 != 1 || p.nem3 != 1) {
+m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
 }
 [[dont_unroll]]
@@ -153,7 +153,7 @@ void main() {
 }
 }
-if (p.mask != 0) {
+if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);


@@ -500,10 +500,9 @@ void main() {
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint ib32 = (idx % 128) / 16; // 0..7
-const uint ib8 = (idx % 128) / 4;
-const int i8 = 2 * int(idx % 4);
+const uint ib = idx / 32; // 8 values per idx
+const uint ib32 = (idx % 32) / 4; // 0..7
+const uint ib8 = idx % 32;
 const float d = float(data_a[ib].d);
 const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
 const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
 const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
-const ivec2 gvec = ivec2(
-bitfieldExtract(grid, 2 * (i8), 2),
-bitfieldExtract(grid, 2 * (i8 + 1), 2)
-);
-const vec2 v = dl * (vec2(gvec) + delta);
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+[[unroll]] for (int k = 0; k < 8; ++k) {
+buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+}
 #elif defined(DATA_A_IQ1_M)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint ib8 = (idx % 128) / 4;
+const uint ib = idx / 32; // 8 values per idx
+const uint ib8 = idx % 32;
 const uint ib16 = ib8 / 2;
-const int i8 = 2 * int(idx % 4);
 const uint16_t[4] scales = data_a[ib].scales;
 const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
 const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
 const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
 const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
-const ivec2 gvec = ivec2(
-bitfieldExtract(grid, 2 * (i8), 2),
-bitfieldExtract(grid, 2 * (i8 + 1), 2)
-);
-const vec2 v = dl * (vec2(gvec) + delta);
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+[[unroll]] for (int k = 0; k < 8; ++k) {
+buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+}
 #elif defined(DATA_A_IQ2_XXS)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint ib32 = (idx % 128) / 16; // 0..7
-const uint ib8 = (idx / 4) % 4;
+const uint ib = idx / 32; // 8 values per idx
+const uint ib32 = (idx % 32) / 4; // 0..7
+const uint ib8 = idx % 4;
 const float d = float(data_a[ib].d);
 const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
 data_a[ib].qs[8*ib32 + 6],
 data_a[ib].qs[8*ib32 + 7]
 ));
-const float db = d * 0.25 * (0.5 + (signs >> 28));
+const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
 const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
-const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
-const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+const uint sign = sign7 | (bitCount(sign7) << 7);
+const uvec2 grid = iq2xxs_grid[qs];
+const vec4 grid0 = vec4(unpack8(grid.x));
+const vec4 grid1 = vec4(unpack8(grid.y));
+buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ2_XS)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint ib32 = (idx % 128) / 16; // 0..7
-const uint ib8 = (idx / 4) % 4; // 0..3
+const uint ib = idx / 32; // 8 values per idx
+const uint ib32 = (idx % 32) / 4; // 0..7
+const uint ib8 = idx % 4; // 0..3
 const float d = float(data_a[ib].d);
 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
-const float db = d * 0.25 * (0.5 + scale);
+const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
 const uint qs = data_a[ib].qs[4 * ib32 + ib8];
 const uint sign7 = qs >> 9;
-const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
-const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+const uint sign = sign7 | (bitCount(sign7) << 7);
+const uvec2 grid = iq2xs_grid[qs & 511];
+const vec4 grid0 = vec4(unpack8(grid.x));
+const vec4 grid1 = vec4(unpack8(grid.y));
+buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ2_S)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint ib8 = (idx % 128) / 4; // 0..31
+const uint ib = idx / 32; // 8 values per idx
+const uint ib8 = idx % 32; // 0..31
 const uint ib32 = ib8 / 4; // 0..7
 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
 const uint qs = data_a[ib].qs[ib8];
 const uint qh = data_a[ib].qh[ib32];
 const uint qhshift = 2 * (ib8 % 4);
-const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
+const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
 const float d = float(data_a[ib].d);
-const float db = d * 0.25 * (0.5 + scale);
-const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
-const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
+const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
+const vec4 grid0 = vec4(unpack8(grid.x));
+const vec4 grid1 = vec4(unpack8(grid.y));
+buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
+buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
+buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
+buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
+buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
+buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
+buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
+buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ3_XXS)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint iqs = (idx % 128) / 2; // 0..63
+const uint ib = idx / 64; // 4 values per idx
+const uint iqs = idx % 64; // 0..63
 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
 const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
 ));
 const float db = d * 0.5 * (0.5 + (signs >> 28));
 const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
-const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
-const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
+const uint grid = iq3xxs_grid[qs];
+const vec4 v = db * vec4(unpack8(grid));
+buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
+buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
+buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
+buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
 #elif defined(DATA_A_IQ3_S)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-const uint ib = idx / 128; // 2 values per idx
-const uint iqs = (idx % 128) / 2; // 0..63
+const uint ib = idx / 64; // 4 values per idx
+const uint iqs = idx % 64; // 0..63
 const uint iqh = iqs / 8;
 const float d = float(data_a[ib].d);
 const uint qs = data_a[ib].qs[iqs];
 const uint qh = data_a[ib].qh[iqh];
-const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
+const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
 const uint scale = data_a[ib].scales[iqs / 16];
 const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
 const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
-const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
-const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
+const vec4 v = db * vec4(unpack8(grid));
+buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
+buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
+buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
+buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
 #elif defined(DATA_A_IQ4_XS)
 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
 const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
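Note on the reworked IQ1/IQ2/IQ3 loads above: each thread now unpacks a whole grid entry (8 or 4 values) instead of 2, and the sign handling relies on these formats storing only 7 sign bits per group of 8, with the 8th reconstructed as the popcount parity of the first 7. A stand-alone C++ sketch of those two pieces with made-up inputs (GLSL's bitfieldExtract on a signed value sign-extends, which the 2-bit extraction below mirrors):

    #include <cstdint>
    #include <cstdio>

    // 8th sign bit = parity of the 7 stored bits, as in sign7 | (bitCount(sign7) << 7).
    static uint32_t full_signs(uint32_t sign7) {
        uint32_t count = 0;
        for (uint32_t b = sign7; b != 0; b >>= 1) count += b & 1;
        return sign7 | (count << 7);
    }

    // IQ1-style unpack: eight signed 2-bit grid entries, scaled by dl and offset by delta.
    static void dequant_iq1_block(int16_t grid, float dl, float delta, float out[8]) {
        for (int k = 0; k < 8; ++k) {
            int v = (uint16_t(grid) >> (2 * k)) & 0x3;
            if (v & 0x2) v -= 4; // sign-extend the 2-bit field
            out[k] = dl * (float(v) + delta);
        }
    }

    int main() {
        float out[8];
        dequant_iq1_block(int16_t(0x1B1B), 0.5f, 0.125f, out); // arbitrary example block
        printf("%f %f sign_mask=%u\n", out[0], out[7], full_signs(0x07));
        return 0;
    }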


@@ -374,9 +374,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 for (const auto& tname : type_names) {
 std::string load_vec_quant = "2";
-if ((tname == "q4_0") || (tname == "q4_1"))
+if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
 load_vec_quant = "8";
-else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
 load_vec_quant = "4";
 if (tname == "bf16") {

File diff suppressed because it is too large.


@@ -132,6 +132,28 @@ def test_chat_template():
 assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+@pytest.mark.parametrize("prefill,re_prefill", [
+("Whill", "Whill"),
+([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
+])
+def test_chat_template_assistant_prefill(prefill, re_prefill):
+global server
+server.chat_template = "llama3"
+server.debug = True # to get the "__verbose" object in the response
+server.start()
+res = server.make_request("POST", "/chat/completions", data={
+"max_tokens": 8,
+"messages": [
+{"role": "system", "content": "Book"},
+{"role": "user", "content": "What is the best book"},
+{"role": "assistant", "content": prefill},
+]
+})
+assert res.status_code == 200
+assert "__verbose" in res.body
+assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
 def test_apply_chat_template():
 global server
 server.chat_template = "command-r"
@@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
 [{"role": "system", "content": 123}],
 # [{"content": "hello"}], # TODO: should not be a valid case
 [{"role": "system", "content": "test"}, {}],
+[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
 ])
 def test_invalid_chat_completion_req(messages):
 global server


@@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
 /* Append assistant prefilled message */
 if (prefill_assistant_message) {
-chat_params.prompt += last_message.content;
+if (!last_message.content_parts.empty()) {
+for (auto & p : last_message.content_parts) {
+chat_params.prompt += p.text;
+}
+} else {
+chat_params.prompt += last_message.content;
+}
 }
 llama_params["chat_format"] = static_cast<int>(chat_params.format);
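Note: together with the server change above, an assistant prefill may now be sent either as a plain string or as a list of text parts, which are concatenated in order into the prompt (this is what the new test_chat_template_assistant_prefill case exercises). A stand-alone sketch of the same concatenation using nlohmann::json, which is an assumption here rather than the server's actual message type:

    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // Assistant prefill given as content parts, mirroring the new test case.
        json last_message = {
            {"role", "assistant"},
            {"content", json::array({
                {{"type", "text"}, {"text", "Wh"}},
                {{"type", "text"}, {"text", "ill"}},
            })},
        };

        std::string prompt;
        if (last_message["content"].is_array()) {
            for (const auto & part : last_message["content"]) {
                prompt += part["text"].get<std::string>(); // "Wh" then "ill"
            }
        } else {
            prompt += last_message["content"].get<std::string>();
        }
        std::cout << prompt << "\n"; // prints "Whill"
        return 0;
    }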