Merge commit '7aaeedc098' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	docs/ops.md
#	docs/ops/SYCL.csv
#	docs/ops/Vulkan.csv
#	ggml/src/ggml-cann/acl_tensor.cpp
#	ggml/src/ggml-cann/acl_tensor.h
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/aclnn_ops.h
#	ggml/src/ggml-cann/common.h
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/rms_norm.cl
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tests/test-backend-ops.cpp
This commit is contained in:
LostRuins Concedo 2025-11-17 23:33:09 +08:00
commit 008405e8fd
31 changed files with 1487 additions and 339 deletions

View file

@ -7840,12 +7840,6 @@ class Glm4MoeModel(TextModel):
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
# Patch broken chat template
if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
special_vocab.chat_template = special_vocab.chat_template.replace(
"""{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
"""{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
@ -9395,16 +9389,6 @@ class HunYuanModel(TextModel):
class SmolLM3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.SMOLLM3
def set_vocab(self):
super().set_vocab()
# remove unsupported array slicing in chat template
# ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
if tokenizer.chat_template is not None:
chat_template = tokenizer.chat_template.replace("[:]", "")
self.gguf_writer.add_chat_template(chat_template)
@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):

View file

@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_CUMSUM);
char base[256];
char name[256];
snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);

View file

@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);

View file

@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_CUMSUM:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
@ -988,7 +989,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
}
case GGML_TYPE_I32:
return op->type == GGML_TYPE_F32;
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
default:
return false;
};

View file

@ -612,6 +612,45 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_sum_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t net0;
int64_t net1;
int64_t net2;
int64_t net3;
uint64_t nbt0;
uint64_t nbt1;
uint64_t nbt2;
uint64_t nbt3;
bool outb;
} ggml_metal_kargs_cumsum_blk;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t net0;
int64_t net1;
int64_t net2;
int64_t net3;
uint64_t nbt0;
uint64_t nbt1;
uint64_t nbt2;
uint64_t nbt3;
} ggml_metal_kargs_cumsum_add;
typedef struct {
int32_t ne00;
int32_t ne01;

View file

@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_sum_rows(ctx, idx);
} break;
case GGML_OP_CUMSUM:
{
n_fuse = ggml_metal_op_cumsum(ctx, idx);
} break;
case GGML_OP_SOFT_MAX:
{
n_fuse = ggml_metal_op_soft_max(ctx, idx);
@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
if (op->src[1]) {
GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
//if (src1) {
// [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
//} else {
// [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
//}
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setBytes:&args length:sizeof(args) atIndex:3];
//[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
//[encoder setComputePipelineState:pipeline];
//[encoder setBytes:&args length:sizeof(args) atIndex:0];
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
int nth = 1;
while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
nth *= 2;
}
GGML_ASSERT(ne00 <= nth*nth);
const int64_t net0 = (ne00 + nth - 1) / nth;
const int64_t net1 = ne01;
const int64_t net2 = ne02;
const int64_t net3 = ne03;
const uint64_t nbt0 = sizeof(float);
const uint64_t nbt1 = net0*nbt0;
const uint64_t nbt2 = net1*nbt1;
const uint64_t nbt3 = net2*nbt2;
const size_t smem = GGML_PAD(32*sizeof(float), 16);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
{
ggml_metal_kargs_cumsum_blk args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
/*.outb =*/ ne00 > nth,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
}
if (ne00 > nth) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_kargs_cumsum_blk args = {
/*.ne00 =*/ net0,
/*.ne01 =*/ net1,
/*.ne02 =*/ net2,
/*.ne03 =*/ net3,
/*.nb00 =*/ nbt0,
/*.nb01 =*/ nbt1,
/*.nb02 =*/ nbt2,
/*.nb03 =*/ nbt3,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
/*.outb =*/ false,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
}
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
ggml_metal_kargs_cumsum_add args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.net0 =*/ net0,
/*.net1 =*/ net1,
/*.net2 =*/ net2,
/*.net3 =*/ net3,
/*.nbt0 =*/ nbt0,
/*.nbt1 =*/ nbt1,
/*.nbt2 =*/ nbt2,
/*.nbt3 =*/ nbt3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline_add);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
}
}
return 1;
}
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float max_bias;
@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_ssm_conv args = {
/*.ne00 =*/ ne00,
@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const ggml_tensor * src3 = op->src[3];
const ggml_tensor * src4 = op->src[4];
@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
const int64_t T = op->src[0]->ne[2];
@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t * opts = op->op_params;
ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ne00 == ne10);
@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
// src2 = ids
GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@ -2191,8 +2318,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
@ -2222,8 +2347,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
@ -2363,8 +2486,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {
@ -2695,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
@ -2743,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t ngrp = ((const int32_t *) op->op_params)[0];
@ -2798,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
@ -2934,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
@ -3028,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@ -3178,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
@ -3223,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
@ -3277,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float sf0 = (float)ne0/op->src[0]->ne[0];
const float sf1 = (float)ne1/op->src[0]->ne[1];
@ -3330,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_pad args = {
/*.ne00 =*/ ne00,
@ -3374,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_pad_reflect_1d args = {
/*.ne00 =*/ ne00,
@ -3418,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float start;
float step;
@ -3436,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
//[encoder setComputePipelineState:pipeline];
//[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
//[encoder setBytes:&args length:sizeof(args) atIndex:1];
//[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
@ -3460,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const int dim = op->op_params[0];
const int max_period = op->op_params[1];
@ -3494,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_kargs_argmax args = {
/*.ne00 = */ ne00,
@ -3535,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
@ -3545,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
nth *= 2;
}
const int nptg = (ne00 + nth - 1)/nth;
const int npr = (ne00 + nth - 1)/nth;
// Metal kernels require the buffer size to be multiple of 16 bytes
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
@ -3557,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
std::swap(bid_dst, bid_tmp);
}
@ -3579,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
@ -3611,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
std::swap(bid_dst, bid_tmp);
@ -3632,7 +3745,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
@ -3668,7 +3781,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
@ -3704,7 +3817,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);

View file

@ -52,6 +52,7 @@ int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);

View file

@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
case GGML_OP_CUMSUM:
case GGML_OP_ARGSORT:
{
res *= 2;

View file

@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
float s = simd_prefix_inclusive_sum(v);
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
template<typename T>
kernel void kernel_cumsum_add(
constant ggml_metal_kargs_cumsum_add & args,
device const char * tmp,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
if (ib == 0) {
return;
}
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * tmp_row = (device const float *) (tmp +
args.nbt1*i01 +
args.nbt2*i02 +
args.nbt3*i03);
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
}
}
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * smem_i32 [[threadgroup(0)]],
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
smem_i32[col] = i00 + col;
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (smem_i32[col] >= args.ne00 ||
(smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(smem_i32[col], smem_i32[ixj]);
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (smem_i32[ixj] >= args.ne00 ||
(smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(smem_i32[col], smem_i32[ixj]);
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
if (i00 + col < args.ne00) {
dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
dst[col] = smem_i32[col];
dst[col] = shmem_i32[col];
}
}
@ -4628,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int im = tgpig[0] / args.ne01;
int i01 = tgpig[0] % args.ne01;
int i02 = tgpig[1];
int i03 = tgpig[2];
const int start = im * (2*args.len);
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
@ -4657,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
+ args.nb02*i02
+ args.nb03*i03);
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
// find partition (i,j) such that i+j = k
int low = k > len1 ? k - len1 : 0;
int high = MIN(k, len0);
if (total == 0) {
return;
}
while (low < high) {
const int mid = (low + high) >> 1;
const int chunk = (total + ntg.x - 1) / ntg.x;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k - mid - 1];
const int k0 = tpitg.x * chunk;
const int k1 = min(k0 + chunk, total);
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (k0 >= total) {
return;
}
if (order == GGML_SORT_ORDER_ASC) {
if (val0 <= val1) {
low = mid + 1;
} else {
high = mid;
}
} else {
if (val0 >= val1) {
low = mid + 1;
} else {
high = mid;
}
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
const int i = low;
const int j = k - i;
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
out_idx = tmp1[j];
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
out_idx = tmp0[i];
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
const int32_t idx0 = tmp0[i];
const int32_t idx1 = tmp1[j];
bool take_left;
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
out_idx = (order == GGML_SORT_ORDER_ASC)
? (val0 <= val1 ? idx0 : idx1)
: (val0 >= val1 ? idx0 : idx1);
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
@ -6401,6 +6560,7 @@ template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif

View file

@ -0,0 +1,273 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#define LM_FIRST_256B 0
#define LM_SECOND_256B 64
#define LM_THIRD_256B 128
#define LM_FOURTH_256B 192
inline float16 mm_load_a(
image1d_buffer_t matrix_A,
uint subMatrixAStartInElements,
int nb01,
int line_stride_matrix_A_in_bytes
) {
__private float8 regA;
size_t sub_block_id_m = get_local_id(0);
#ifdef KQV
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
#else // KQ
uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
#endif
regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
return convert_float16(as_half16(regA));
}
inline float4 alu_32(
float16 regA,
__local float4* matrix_B_vec
) {
__private float4 rC = 0;
int i = get_sub_group_id() * 64;
rC += regA.s0 * matrix_B_vec[i];
rC += regA.s1 * matrix_B_vec[i + 16];
rC += regA.s4 * matrix_B_vec[i + 1];
rC += regA.s5 * matrix_B_vec[i + 17];
rC += regA.s8 * matrix_B_vec[i + 2];
rC += regA.s9 * matrix_B_vec[i + 18];
rC += regA.sc * matrix_B_vec[i + 3];
rC += regA.sd * matrix_B_vec[i + 19];
i += 32;
rC += regA.s2 * matrix_B_vec[i];
rC += regA.s3 * matrix_B_vec[i + 16];
rC += regA.s6 * matrix_B_vec[i + 1];
rC += regA.s7 * matrix_B_vec[i + 17];
rC += regA.sa * matrix_B_vec[i + 2];
rC += regA.sb * matrix_B_vec[i + 18];
rC += regA.se * matrix_B_vec[i + 3];
rC += regA.sf * matrix_B_vec[i + 19];
return rC;
}
inline float16 alu_16(
float16 regA,
__local float* matrix_B_local
) {
float16 out;
__local float4* matrix_B_vec = (__local float4*)matrix_B_local;
out.s0123 = alu_32(regA, matrix_B_vec);
out.s4567 = alu_32(regA, matrix_B_vec + 4);
out.s89ab = alu_32(regA, matrix_B_vec + 8);
out.scdef = alu_32(regA, matrix_B_vec + 12);
return out;
}
inline void mm_mad(
__local float* matrix_B_local,
float16 regA,
float8 regB,
uint b_localOffsetInWords,
float16* regC0_ptr,
float16* regC1_ptr
) {
int offset = b_localOffsetInWords + get_sub_group_id() * 256;
matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
float16 add0 = alu_16(regA, matrix_B_local);
*regC0_ptr += add0;
matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
float16 add1 = alu_16(regA, matrix_B_local);
*regC1_ptr += add1;
}
inline void mm_store_c_N(
__write_only image1d_buffer_t matrix_C,
float16 regC0,
float16 regC1,
uint subMatrixCStartInElements,
int line_stride_matrix_C_in_bytes,
int mask
) {
size_t sub_block_id_m = get_local_id(0);
uint strideInWords = line_stride_matrix_C_in_bytes/4;
uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
}
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
#ifdef KQV
__kernel void mul_mm_f16_f32_kqv(
#else
__kernel void mul_mm_f16_f32_kq(
#endif
__read_only image1d_buffer_t matrix_A,
int offset0,
__global float* matrix_B,
int offset1,
__write_only image1d_buffer_t matrix_C,
int offsetd,
int M, int K, int N,
int D_A,
int D_B,
int nb01
) {
uint block_id_m = get_global_id(1);
uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
__private float16 regA;
__private float8 regB;
__private float16 regC0;
__private float16 regC1;
const uint col = block_id_m * TILESIZE_M;
const uint row = block_id_n * TILESIZE_N;
const uint depth_A = block_id_d / (D_B/D_A);
const uint depth_B = block_id_d;
#ifdef KQV
int line_stride_matrix_A_in_bytes = nb01 * M;
int line_stride_matrix_B_in_bytes = K * N * 4;
#else
int line_stride_matrix_A_in_bytes = K * D_A * 2;
int line_stride_matrix_B_in_bytes = K * D_B * 4;
#endif
int line_stride_matrix_C_in_bytes = M * 4;
const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
size_t sub_block_id_m = get_local_id(0);
uint b_localOffsetInWords = (sub_block_id_m/16)*16
+ ((((sub_block_id_m)>>0)&1)<<2)
+ ((((sub_block_id_m)>>1)&1)<<3)
+ ((((sub_block_id_m)>>2)&1)<<0)
+ ((((sub_block_id_m)>>3)&1)<<1);
uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
uint b_globalOffsetInWords00, b_globalOffsetInWords16;
#ifdef KQV
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
#else
b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
#endif
__local float matrix_B_local[1024];
for (uint step=0; step < K; step+=TILESIZE_K) {
size_t sub_block_id_m = get_local_id(0);
regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
subMatrixAStartInElements += TILESIZE_K;
subMatrixBStartInElements += TILESIZE_K;
}
uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
}

View file

@ -44,6 +44,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#include <memory>
#include <limits>
#include <map>
#include <set>
#include <unordered_map>
#include <memory>
#include <mutex>
@ -644,6 +645,7 @@ struct vk_device_struct {
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
@ -840,6 +842,12 @@ struct vk_mat_mat_push_constants {
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
};
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
struct vk_mat_vec_push_constants {
uint32_t ncols;
uint32_t stride_a;
@ -848,8 +856,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
uint32_t enable_scale;
uint32_t fusion_flags;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
@ -863,7 +870,7 @@ struct vk_mat_vec_p021_push_constants {
uint32_t nchannels_y;
uint32_t b_offset;
uint32_t d_offset;
uint32_t enable_bias;
uint32_t fusion_flags;
};
struct vk_mat_vec_nc_push_constants {
@ -879,7 +886,7 @@ struct vk_mat_vec_nc_push_constants {
uint32_t nb03;
uint32_t nb13;
uint32_t nb23;
uint32_t enable_bias;
uint32_t fusion_flags;
};
struct vk_mat_mat_id_push_constants {
@ -897,8 +904,7 @@ struct vk_mat_vec_id_push_constants {
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
uint32_t enable_scale;
uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
};
@ -3481,8 +3487,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
static constexpr uint32_t mul_mat_vec_num_bindings = 4;
static constexpr uint32_t mul_mat_vec_id_num_bindings = 5;
static constexpr uint32_t mul_mat_vec_num_bindings = 5;
static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
@ -3803,6 +3809,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -6469,7 +6477,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@ -6756,7 +6764,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
@ -6901,21 +6909,31 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
groups_x = CEIL_DIV(groups_x, groups_z);
}
uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
uint32_t fusion_flags = 0;
vk_subbuffer d_B = d_D;
if (enable_bias) {
vk_subbuffer d_F0 = d_D;
if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
d_B = ggml_vk_tensor_subbuffer(ctx, bias);
d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
}
vk_subbuffer d_F1 = d_D;
if (ctx->num_additional_fused_ops == 2) {
const ggml_tensor * add = cgraph->nodes[node_idx + 2];
const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@ -6923,7 +6941,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
d_X,
d_Y,
d_D,
d_B,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
@ -6976,22 +6995,31 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
vk_subbuffer d_B = d_D;
vk_subbuffer d_F0 = d_D;
uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
uint32_t fusion_flags = 0;
if (enable_bias) {
if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
d_B = ggml_vk_tensor_subbuffer(ctx, bias);
d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
}
vk_subbuffer d_F1 = d_D;
if (ctx->num_additional_fused_ops > 1) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
vk_mat_vec_p021_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
0, 0, enable_bias
0, 0, fusion_flags
};
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@ -7007,7 +7035,8 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
d_Qx,
d_Qy,
d_D,
d_B,
d_F0,
d_F1,
}, pc, { 1, (uint32_t)ne01, workgroups_z });
}
@ -7059,15 +7088,24 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
vk_subbuffer d_B = d_D;
vk_subbuffer d_F0 = d_D;
uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
uint32_t fusion_flags = 0;
if (enable_bias) {
if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
d_B = ggml_vk_tensor_subbuffer(ctx, bias);
d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
}
vk_subbuffer d_F1 = d_D;
if (ctx->num_additional_fused_ops > 1) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
@ -7076,7 +7114,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
row_stride_x, channel_stride_x, channel_stride_y,
(uint32_t)(ne12 / ne02), (uint32_t)ne12,
0, 0,
nb03, nb13, nb23, enable_bias
nb03, nb13, nb23, fusion_flags
};
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@ -7086,7 +7124,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
d_Qx,
d_Qy,
d_D,
d_B,
d_F0,
d_F1,
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
@ -7214,7 +7253,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
@ -7507,7 +7546,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
vk_subbuffer d_B = d_D;
vk_subbuffer d_F0 = d_D;
vk_subbuffer d_X, d_Y;
if (qx_needs_dequant) {
@ -7560,30 +7599,34 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
groups_x = CEIL_DIV(groups_x, groups_z);
}
uint32_t enable_bias = 0;
uint32_t enable_scale = 0;
uint32_t fusion_flags = 0;
if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
enable_scale = 1;
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
} else {
GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
enable_bias = 1;
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
}
}
if (enable_bias || enable_scale) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
vk_subbuffer d_F1 = d_D;
if (ctx->num_additional_fused_ops > 1) {
const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
d_B = ggml_vk_tensor_subbuffer(ctx, bias);
d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
enable_bias, enable_scale,
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@ -7591,7 +7634,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
d_X,
d_Y,
d_D,
d_B,
d_F0,
d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
@ -8115,6 +8159,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_LOG:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
@ -8523,6 +8573,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
@ -8795,6 +8846,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@ -9403,6 +9455,10 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
}
static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
@ -11198,6 +11254,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@ -11422,6 +11479,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COS:
ggml_vk_cos(ctx, compute_ctx, src0, node);
break;
case GGML_OP_LOG:
ggml_vk_log(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@ -11692,6 +11753,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@ -12335,10 +12397,7 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
return false;
}
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
// additional constraints specific to this fusion
const ggml_tensor *mul = cgraph->nodes[node_idx];
const ggml_tensor *add = cgraph->nodes[node_idx + 1];
auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
// mat-vec only
@ -12358,8 +12417,60 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
if (get_misalign_bytes(ctx, bias) != 0) {
return false;
}
return true;
};
if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
// additional constraints specific to this fusion
const ggml_tensor *mul = cgraph->nodes[node_idx];
const ggml_tensor *add = cgraph->nodes[node_idx + 1];
if (!mm_add_ok(mul, add)) {
return false;
}
if (ops.size() == 3) {
if (ops.begin()[2] != GGML_OP_ADD) {
return false;
}
if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
return false;
}
}
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
const ggml_tensor *scale = mul->src[1];
if (mmid != mul->src[0]) {
return false;
}
// mat-vec only
if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
return false;
}
// shaders assume the types match
if (mmid->type != scale->type) {
return false;
}
// shaders assume the bias is contiguous
if (!ggml_is_contiguous(scale)) {
return false;
}
// unaligned bias isn't handled
if (get_misalign_bytes(ctx, scale) != 0) {
return false;
}
// shader only indexes by expert index
if (scale->ne[0] != 1 ||
scale->ne[1] != mul->ne[1] ||
scale->ne[2] != 1 ||
scale->ne[3] != 1) {
return false;
}
return true;
};
if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
// additional constraints specific to this fusion
const ggml_tensor *mul = cgraph->nodes[node_idx];
const ggml_tensor *add = cgraph->nodes[node_idx + 1];
@ -12388,38 +12499,22 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
if (get_misalign_bytes(ctx, bias) != 0) {
return false;
}
if (ops.size() == 3) {
if (ops.begin()[2] != GGML_OP_MUL) {
return false;
}
const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
return mmid_mul_ok(add, mul);
}
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
// additional constraints specific to this fusion
const ggml_tensor *mmid = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
const ggml_tensor *scale = mul->src[1];
if (mmid != mul->src[0]) {
return false;
}
// mat-vec only
if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
return false;
}
// shaders assume the types match
if (mmid->type != scale->type) {
return false;
}
// shaders assume the bias is contiguous
if (!ggml_is_contiguous(scale)) {
return false;
}
// unaligned bias isn't handled
if (get_misalign_bytes(ctx, scale) != 0) {
return false;
}
// shader only indexes by expert index
if (scale->ne[0] != 1 ||
scale->ne[1] != mul->ne[1] ||
scale->ne[2] != 1 ||
scale->ne[3] != 1) {
if (!mmid_mul_ok(mmid, mul)) {
return false;
}
}
@ -12734,8 +12829,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
ctx->num_additional_fused_ops = 2;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 2;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
@ -12902,6 +13001,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false);
std::set<ggml_tensor *> used_node_set;
int first_unused = 0;
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
@ -12924,6 +13025,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (match_pattern(pattern, first_unused)) {
for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used_node_set.insert(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
@ -13027,6 +13129,36 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
used[set_rows_idx] = true;
}
}
// Look for MUL_MAT_ID + ADD_ID + MUL
if (j > 0 &&
graph->nodes[j]->op == GGML_OP_ADD_ID &&
graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
if (graph->nodes[k]->op == GGML_OP_MUL &&
graph->nodes[k]->src[0] == graph->nodes[j] &&
// src1 must either be weights or already processed
(graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
current_set.push_back(k);
used[k] = true;
break;
}
}
}
// Look for MUL_MAT + ADD + ADD
if (j > 0 &&
graph->nodes[j]->op == GGML_OP_ADD &&
graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
if (graph->nodes[k]->op == GGML_OP_ADD &&
graph->nodes[k]->src[0] == graph->nodes[j] &&
// src1 must either be weights or already processed
(graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
current_set.push_back(k);
used[k] = true;
break;
}
}
}
}
}
// Second pass grabs view nodes.
@ -13059,6 +13191,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
// Push the current set into new_order
for (auto c : current_set) {
new_order.push_back(graph->nodes[c]);
used_node_set.insert(graph->nodes[c]);
used[c] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
@ -13582,6 +13715,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
case GGML_OP_UPSCALE:
@ -14077,6 +14212,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COS) {
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);

View file

@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
}

View file

@ -11,29 +11,7 @@
#define EXPERT_COUNT 8
#endif
#include "types.glsl"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#ifdef MUL_MAT_ID
layout (binding = 4) readonly buffer IDS {int data_ids[];};
#endif
#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
@ -48,8 +26,7 @@ layout (push_constant) uniform parameter
uint batch_stride_b;
uint batch_stride_d;
uint enable_bias;
uint enable_scale;
uint fusion_flags;
#ifdef MUL_MAT_ID
uint nei0;
@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
#else
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
#ifdef MUL_MAT_ID
if (p.enable_scale != 0) {
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);

View file

@ -0,0 +1,33 @@
#include "types.glsl"
#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};
layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};
#ifdef MUL_MAT_ID
layout (binding = 5) readonly buffer IDS {int data_ids[];};
#endif

View file

@ -8,14 +8,7 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#include "mul_mat_vec_iface.glsl"
layout (push_constant) uniform parameter
{
@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
uint nb03;
uint nb13;
uint nb23;
uint enable_bias;
uint fusion_flags;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
@ -120,9 +113,12 @@ void main() {
}
if (tid == 0) {
if (p.enable_bias != 0) {
tmp[0] += FLOAT_TYPE(data_bias[idst]);
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
}
dst[idst] = tmp[0];
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
}
data_d[idst] = tmp[0];
}
}

View file

@ -10,14 +10,7 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
#include "mul_mat_vec_iface.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
uint nchannels_y;
uint b_offset;
uint d_offset;
uint enable_bias;
uint fusion_flags;
} p;
#if !USE_SUBGROUP_ADD
@ -151,10 +144,13 @@ void main() {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
if (p.enable_bias != 0) {
temp[c] += FLOAT_TYPE(data_bias[idst]);
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
temp[c] += FLOAT_TYPE(data_fuse0[idst]);
}
dst[idst] = temp[c];
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
temp[c] += FLOAT_TYPE(data_fuse1[idst]);
}
data_d[idst] = temp[c];
}
}
}

View file

@ -819,6 +819,9 @@ void process_shaders() {
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

Binary file not shown.

View file

@ -1686,14 +1686,13 @@ struct server_slot {
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
}
void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
bool res = prompt_cache.load(prompt, tokens, ctx, id);
if (!res) {
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
prompt.tokens.clear();
}
return res;
}
std::vector<common_adapter_lora_info> lora;
@ -2339,7 +2338,6 @@ struct server_context {
llama_batch batch {};
bool clean_kv_cache = true;
bool add_bos_token = true;
int32_t n_ctx; // total context for all clients / slots
@ -2702,7 +2700,10 @@ struct server_context {
const int64_t t_start = ggml_time_us();
ret->prompt_save(*prompt_cache);
ret->prompt_load(*prompt_cache, task.tokens);
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
clear_slot(*ret);
}
prompt_cache->update();
@ -2713,12 +2714,21 @@ struct server_context {
return ret;
}
// return true if at least one slot has been purged
void clear_slot(server_slot & slot) const {
GGML_ASSERT(!slot.is_processing());
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
}
// return true if at least one slot has been cleared
// TODO: improve logic
// - smarter decision which slot to purge (LRU or longest prompt?)
// - smarter decision which slot to clear (LRU or longest prompt?)
// - move slot to level 2 cache instead of removing?
// - instead of purging, try to store and resume later?
bool try_purge_idle_slots() {
bool try_clear_idle_slots() {
bool res = false;
if (!params_base.kv_unified) {
@ -2733,12 +2743,11 @@ struct server_context {
if (slot.prompt.n_tokens() > 0) {
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
slot.prompt.tokens.clear();
clear_slot(slot);
res = true;
// purge slots one by one
// clear slots one by one
break;
}
}
@ -2848,14 +2857,6 @@ struct server_context {
return true;
}
void kv_cache_clear() {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
llama_memory_clear(llama_get_memory(ctx), true);
clean_kv_cache = false;
}
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
@ -3443,8 +3444,8 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->prompt.tokens.size();
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
slot->prompt.tokens.clear();
clear_slot(*slot);
auto res = std::make_unique<server_task_result_slot_erase>();
res->id = task.id;
@ -3477,9 +3478,6 @@ struct server_context {
if (all_idle) {
SRV_INF("%s", "all slots are idle\n");
if (clean_kv_cache) {
kv_cache_clear();
}
return;
}
@ -3873,12 +3871,11 @@ struct server_context {
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
clear_slot(slot);
// there is no common part left
slot.n_prompt_tokens_cache = 0;
slot.prompt.tokens.clear();
}
// check if we should process the image
@ -4108,6 +4105,10 @@ struct server_context {
if (slot.is_processing()) {
send_error(slot, err);
slot.release();
// note: it's complicated to keep track of how much of the current batch has been
// processed before the error occurred, so we simply clear the entire context
clear_slot(slot);
}
}
@ -4116,7 +4117,7 @@ struct server_context {
}
// retry with half the batch size to try to find a free slot in the KV cache
if (!try_purge_idle_slots()) {
if (!try_clear_idle_slots()) {
n_batch /= 2;
}

View file

@ -72,12 +72,6 @@
}
}
function handleScroll() {
if (isOpen) {
updateMenuPosition();
}
}
async function handleSelect(value: string | undefined) {
if (!value) return;
@ -259,7 +253,7 @@
}
</script>
<svelte:window onresize={handleResize} onscroll={handleScroll} />
<svelte:window onresize={handleResize} />
<svelte:document onpointerdown={handlePointerDown} onkeydown={handleKeydown} />

View file

@ -2,6 +2,7 @@
import { getDeletionInfo } from '$lib/stores/chat.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import { isIMEComposing } from '$lib/utils/is-ime-composing';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
import ChatMessageUser from './ChatMessageUser.svelte';
@ -54,6 +55,29 @@
return null;
});
let toolCallContent = $derived.by((): ApiChatCompletionToolCall[] | string | null => {
if (message.role === 'assistant') {
const trimmedToolCalls = message.toolCalls?.trim();
if (!trimmedToolCalls) {
return null;
}
try {
const parsed = JSON.parse(trimmedToolCalls);
if (Array.isArray(parsed)) {
return parsed as ApiChatCompletionToolCall[];
}
} catch {
// Harmony-only path: fall back to the raw string so issues surface visibly.
}
return trimmedToolCalls;
}
return null;
});
function handleCancelEdit() {
isEditing = false;
editedContent = message.content;
@ -171,5 +195,6 @@
{showDeleteDialog}
{siblingInfo}
{thinkingContent}
{toolCallContent}
/>
{/if}

View file

@ -11,7 +11,8 @@
Gauge,
Clock,
WholeWord,
ChartNoAxesColumn
ChartNoAxesColumn,
Wrench
} from '@lucide/svelte';
import { Button } from '$lib/components/ui/button';
import { Checkbox } from '$lib/components/ui/checkbox';
@ -21,6 +22,7 @@
import { config } from '$lib/stores/settings.svelte';
import { modelName as serverModelName } from '$lib/stores/server.svelte';
import { copyToClipboard } from '$lib/utils/copy';
import type { ApiChatCompletionToolCall } from '$lib/types/api';
interface Props {
class?: string;
@ -51,6 +53,7 @@
siblingInfo?: ChatMessageSiblingInfo | null;
textareaElement?: HTMLTextAreaElement;
thinkingContent: string | null;
toolCallContent: ApiChatCompletionToolCall[] | string | null;
}
let {
@ -76,9 +79,15 @@
shouldBranchAfterEdit = false,
siblingInfo = null,
textareaElement = $bindable(),
thinkingContent
thinkingContent,
toolCallContent = null
}: Props = $props();
const toolCalls = $derived(
Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null
);
const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null);
const processingState = useProcessingState();
let currentConfig = $derived(config());
let serverModel = $derived(serverModelName());
@ -97,6 +106,58 @@
void copyToClipboard(model ?? '');
}
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
const callNumber = index + 1;
const functionName = toolCall.function?.name?.trim();
const label = functionName || `Call #${callNumber}`;
const payload: Record<string, unknown> = {};
const id = toolCall.id?.trim();
if (id) {
payload.id = id;
}
const type = toolCall.type?.trim();
if (type) {
payload.type = type;
}
if (toolCall.function) {
const fnPayload: Record<string, unknown> = {};
const name = toolCall.function.name?.trim();
if (name) {
fnPayload.name = name;
}
const rawArguments = toolCall.function.arguments?.trim();
if (rawArguments) {
try {
fnPayload.arguments = JSON.parse(rawArguments);
} catch {
fnPayload.arguments = rawArguments;
}
}
if (Object.keys(fnPayload).length > 0) {
payload.function = fnPayload;
}
}
const formattedPayload = JSON.stringify(payload, null, 2);
return {
label,
tooltip: formattedPayload,
copyValue: formattedPayload
};
}
function handleCopyToolCall(payload: string) {
void copyToClipboard(payload, 'Tool call copied to clipboard');
}
</script>
<div
@ -189,6 +250,47 @@
</span>
{/if}
{#if config().showToolCalls}
{#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls}
<span class="inline-flex flex-wrap items-center gap-2 text-xs text-muted-foreground">
<span class="inline-flex items-center gap-1">
<Wrench class="h-3.5 w-3.5" />
<span>Tool calls:</span>
</span>
{#if toolCalls && toolCalls.length > 0}
{#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)}
{@const badge = formatToolCallBadge(toolCall, index)}
<button
type="button"
class="tool-call-badge inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={badge.tooltip}
aria-label={`Copy tool call ${badge.label}`}
onclick={() => handleCopyToolCall(badge.copyValue)}
>
{badge.label}
<Copy class="ml-1 h-3 w-3" />
</button>
{/each}
{:else if fallbackToolCalls}
<button
type="button"
class="tool-call-badge tool-call-badge--fallback inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75"
title={fallbackToolCalls}
aria-label="Copy tool call payload"
onclick={() => handleCopyToolCall(fallbackToolCalls)}
>
{fallbackToolCalls}
<Copy class="ml-1 h-3 w-3" />
</button>
{/if}
</span>
{/if}
{/if}
{#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms}
{@const tokensPerSecond = (message.timings.predicted_n / message.timings.predicted_ms) * 1000}
<span class="inline-flex items-center gap-2 text-xs text-muted-foreground">
@ -287,4 +389,17 @@
white-space: pre-wrap;
word-break: break-word;
}
.tool-call-badge {
max-width: 12rem;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.tool-call-badge--fallback {
max-width: 20rem;
white-space: normal;
word-break: break-word;
}
</style>

View file

@ -76,10 +76,10 @@
});
</script>
<div class="chat-processing-info-container" class:visible={showSlotsInfo}>
<div class="chat-processing-info-container pointer-events-none" class:visible={showSlotsInfo}>
<div class="chat-processing-info-content">
{#each processingDetails as detail (detail)}
<span class="chat-processing-info-detail">{detail}</span>
<span class="chat-processing-info-detail pointer-events-auto">{detail}</span>
{/each}
</div>
</div>
@ -92,7 +92,6 @@
padding: 1.5rem 1rem;
opacity: 0;
transform: translateY(50%);
pointer-events: none;
transition:
opacity 300ms ease-out,
transform 300ms ease-out;
@ -100,7 +99,6 @@
.chat-processing-info-container.visible {
opacity: 1;
pointer-events: auto;
transform: translateY(0);
}

View file

@ -226,6 +226,11 @@
label: 'Enable model selector',
type: 'checkbox'
},
{
key: 'showToolCalls',
label: 'Show tool call labels',
type: 'checkbox'
},
{
key: 'disableReasoningFormat',
label: 'Show raw LLM output',

View file

@ -6,6 +6,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
theme: 'system',
showTokensPerSecond: false,
showThoughtInProgress: false,
showToolCalls: false,
disableReasoningFormat: false,
keepStatsVisible: false,
showMessageStats: true,
@ -80,6 +81,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
custom: 'Custom JSON parameters to send to the API. Must be valid JSON format.',
showTokensPerSecond: 'Display generation speed in tokens per second during streaming.',
showThoughtInProgress: 'Expand thought process by default when generating messages.',
showToolCalls:
'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.',
disableReasoningFormat:
'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.',
keepStatsVisible: 'Keep processing statistics visible after generation finishes.',

View file

@ -1,6 +1,25 @@
import { config } from '$lib/stores/settings.svelte';
import { selectedModelName } from '$lib/stores/models.svelte';
import { slotsService } from './slots';
import type {
ApiChatCompletionRequest,
ApiChatCompletionResponse,
ApiChatCompletionStreamChunk,
ApiChatCompletionToolCall,
ApiChatCompletionToolCallDelta,
ApiChatMessageData
} from '$lib/types/api';
import type {
DatabaseMessage,
DatabaseMessageExtra,
DatabaseMessageExtraAudioFile,
DatabaseMessageExtraImageFile,
DatabaseMessageExtraLegacyContext,
DatabaseMessageExtraPdfFile,
DatabaseMessageExtraTextFile
} from '$lib/types/database';
import type { ChatMessagePromptProgress, ChatMessageTimings } from '$lib/types/chat';
import type { SettingsChatServiceOptions } from '$lib/types/settings';
/**
* ChatService - Low-level API communication layer for llama.cpp server interactions
*
@ -53,6 +72,7 @@ export class ChatService {
onComplete,
onError,
onReasoningChunk,
onToolCallChunk,
onModel,
onFirstValidChunk,
// Generation parameters
@ -201,6 +221,7 @@ export class ChatService {
onComplete,
onError,
onReasoningChunk,
onToolCallChunk,
onModel,
onFirstValidChunk,
conversationId,
@ -208,7 +229,13 @@ export class ChatService {
);
return;
} else {
return this.handleNonStreamResponse(response, onComplete, onError, onModel);
return this.handleNonStreamResponse(
response,
onComplete,
onError,
onToolCallChunk,
onModel
);
}
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
@ -264,10 +291,12 @@ export class ChatService {
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCalls?: string
) => void,
onError?: (error: Error) => void,
onReasoningChunk?: (chunk: string) => void,
onToolCallChunk?: (chunk: string) => void,
onModel?: (model: string) => void,
onFirstValidChunk?: () => void,
conversationId?: string,
@ -282,11 +311,53 @@ export class ChatService {
const decoder = new TextDecoder();
let aggregatedContent = '';
let fullReasoningContent = '';
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
let hasReceivedData = false;
let lastTimings: ChatMessageTimings | undefined;
let streamFinished = false;
let modelEmitted = false;
let firstValidChunkEmitted = false;
let toolCallIndexOffset = 0;
let hasOpenToolCallBatch = false;
const finalizeOpenToolCallBatch = () => {
if (!hasOpenToolCallBatch) {
return;
}
toolCallIndexOffset = aggregatedToolCalls.length;
hasOpenToolCallBatch = false;
};
const processToolCallDelta = (toolCalls?: ApiChatCompletionToolCallDelta[]) => {
if (!toolCalls || toolCalls.length === 0) {
return;
}
aggregatedToolCalls = this.mergeToolCallDeltas(
aggregatedToolCalls,
toolCalls,
toolCallIndexOffset
);
if (aggregatedToolCalls.length === 0) {
return;
}
hasOpenToolCallBatch = true;
const serializedToolCalls = JSON.stringify(aggregatedToolCalls);
if (!serializedToolCalls) {
return;
}
hasReceivedData = true;
if (!abortSignal?.aborted) {
onToolCallChunk?.(serializedToolCalls);
}
};
try {
let chunk = '';
@ -325,6 +396,7 @@ export class ChatService {
const content = parsed.choices[0]?.delta?.content;
const reasoningContent = parsed.choices[0]?.delta?.reasoning_content;
const toolCalls = parsed.choices[0]?.delta?.tool_calls;
const timings = parsed.timings;
const promptProgress = parsed.prompt_progress;
@ -342,6 +414,7 @@ export class ChatService {
}
if (content) {
finalizeOpenToolCallBatch();
hasReceivedData = true;
aggregatedContent += content;
if (!abortSignal?.aborted) {
@ -350,12 +423,15 @@ export class ChatService {
}
if (reasoningContent) {
finalizeOpenToolCallBatch();
hasReceivedData = true;
fullReasoningContent += reasoningContent;
if (!abortSignal?.aborted) {
onReasoningChunk?.(reasoningContent);
}
}
processToolCallDelta(toolCalls);
} catch (e) {
console.error('Error parsing JSON chunk:', e);
}
@ -368,12 +444,26 @@ export class ChatService {
if (abortSignal?.aborted) return;
if (streamFinished) {
if (!hasReceivedData && aggregatedContent.length === 0) {
finalizeOpenToolCallBatch();
if (
!hasReceivedData &&
aggregatedContent.length === 0 &&
aggregatedToolCalls.length === 0
) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(aggregatedContent, fullReasoningContent || undefined, lastTimings);
const finalToolCalls =
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
onComplete?.(
aggregatedContent,
fullReasoningContent || undefined,
lastTimings,
finalToolCalls
);
}
} catch (error) {
const err = error instanceof Error ? error : new Error('Stream error');
@ -386,6 +476,54 @@ export class ChatService {
}
}
private mergeToolCallDeltas(
existing: ApiChatCompletionToolCall[],
deltas: ApiChatCompletionToolCallDelta[],
indexOffset = 0
): ApiChatCompletionToolCall[] {
const result = existing.map((call) => ({
...call,
function: call.function ? { ...call.function } : undefined
}));
for (const delta of deltas) {
const index =
typeof delta.index === 'number' && delta.index >= 0
? delta.index + indexOffset
: result.length;
while (result.length <= index) {
result.push({ function: undefined });
}
const target = result[index]!;
if (delta.id) {
target.id = delta.id;
}
if (delta.type) {
target.type = delta.type;
}
if (delta.function) {
const fn = target.function ? { ...target.function } : {};
if (delta.function.name) {
fn.name = delta.function.name;
}
if (delta.function.arguments) {
fn.arguments = (fn.arguments ?? '') + delta.function.arguments;
}
target.function = fn;
}
}
return result;
}
/**
* Handles non-streaming response from the chat completion API.
* Parses the JSON response and extracts the generated content.
@ -401,9 +539,11 @@ export class ChatService {
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCalls?: string
) => void,
onError?: (error: Error) => void,
onToolCallChunk?: (chunk: string) => void,
onModel?: (model: string) => void
): Promise<string> {
try {
@ -423,17 +563,31 @@ export class ChatService {
const content = data.choices[0]?.message?.content || '';
const reasoningContent = data.choices[0]?.message?.reasoning_content;
const toolCalls = data.choices[0]?.message?.tool_calls;
if (reasoningContent) {
console.log('Full reasoning content:', reasoningContent);
}
if (!content.trim()) {
let serializedToolCalls: string | undefined;
if (toolCalls && toolCalls.length > 0) {
const mergedToolCalls = this.mergeToolCallDeltas([], toolCalls);
if (mergedToolCalls.length > 0) {
serializedToolCalls = JSON.stringify(mergedToolCalls);
if (serializedToolCalls) {
onToolCallChunk?.(serializedToolCalls);
}
}
}
if (!content.trim() && !serializedToolCalls) {
const noResponseError = new Error('No response received from server. Please try again.');
throw noResponseError;
}
onComplete?.(content, reasoningContent);
onComplete?.(content, reasoningContent, undefined, serializedToolCalls);
return content;
} catch (error) {

View file

@ -205,6 +205,7 @@ class ChatStore {
type,
timestamp: Date.now(),
thinking: '',
toolCalls: '',
children: [],
extra: extras
},
@ -360,6 +361,7 @@ class ChatStore {
): Promise<void> {
let streamedContent = '';
let streamedReasoningContent = '';
let streamedToolCallContent = '';
let resolvedModel: string | null = null;
let modelPersisted = false;
@ -468,6 +470,20 @@ class ChatStore {
this.updateMessageAtIndex(messageIndex, { thinking: streamedReasoningContent });
},
onToolCallChunk: (toolCallChunk: string) => {
const chunk = toolCallChunk.trim();
if (!chunk) {
return;
}
streamedToolCallContent = chunk;
const messageIndex = this.findMessageIndex(assistantMessage.id);
this.updateMessageAtIndex(messageIndex, { toolCalls: streamedToolCallContent });
},
onModel: (modelName: string) => {
recordModel(modelName);
},
@ -475,18 +491,21 @@ class ChatStore {
onComplete: async (
finalContent?: string,
reasoningContent?: string,
timings?: ChatMessageTimings
timings?: ChatMessageTimings,
toolCallContent?: string
) => {
slotsService.stopStreaming();
const updateData: {
content: string;
thinking: string;
toolCalls: string;
timings?: ChatMessageTimings;
model?: string;
} = {
content: finalContent || streamedContent,
thinking: reasoningContent || streamedReasoningContent,
toolCalls: toolCallContent || streamedToolCallContent,
timings: timings
};
@ -499,7 +518,11 @@ class ChatStore {
const messageIndex = this.findMessageIndex(assistantMessage.id);
const localUpdateData: { timings?: ChatMessageTimings; model?: string } = {
const localUpdateData: {
timings?: ChatMessageTimings;
model?: string;
toolCalls?: string;
} = {
timings: timings
};
@ -507,6 +530,10 @@ class ChatStore {
localUpdateData.model = updateData.model;
}
if (updateData.toolCalls !== undefined) {
localUpdateData.toolCalls = updateData.toolCalls;
}
this.updateMessageAtIndex(messageIndex, localUpdateData);
await DatabaseStore.updateCurrentNode(assistantMessage.convId, assistantMessage.id);
@ -620,6 +647,7 @@ class ChatStore {
content: '',
timestamp: Date.now(),
thinking: '',
toolCalls: '',
children: [],
model: null
},
@ -1443,6 +1471,7 @@ class ChatStore {
role: messageToEdit.role,
content: newContent,
thinking: messageToEdit.thinking || '',
toolCalls: messageToEdit.toolCalls || '',
children: [],
model: messageToEdit.model // Preserve original model info when branching
},
@ -1518,6 +1547,7 @@ class ChatStore {
role: messageToEdit.role,
content: newContent,
thinking: messageToEdit.thinking || '',
toolCalls: messageToEdit.toolCalls || '',
children: [],
extra: messageToEdit.extra ? JSON.parse(JSON.stringify(messageToEdit.extra)) : undefined,
model: messageToEdit.model // Preserve original model info when branching
@ -1589,6 +1619,7 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},
@ -1647,6 +1678,7 @@ class ChatStore {
role: 'assistant',
content: '',
thinking: '',
toolCalls: '',
children: [],
model: null
},

View file

@ -114,6 +114,7 @@ export class DatabaseStore {
...message,
id: uuid(),
parent: parentId,
toolCalls: message.toolCalls ?? '',
children: []
};
@ -154,6 +155,7 @@ export class DatabaseStore {
content: '',
parent: null,
thinking: '',
toolCalls: '',
children: []
};

View file

@ -183,6 +183,23 @@ export interface ApiChatCompletionRequest {
samplers?: string[];
// Custom parameters (JSON string)
custom?: Record<string, unknown>;
timings_per_token?: boolean;
}
export interface ApiChatCompletionToolCallFunctionDelta {
name?: string;
arguments?: string;
}
export interface ApiChatCompletionToolCallDelta {
index?: number;
id?: string;
type?: string;
function?: ApiChatCompletionToolCallFunctionDelta;
}
export interface ApiChatCompletionToolCall extends ApiChatCompletionToolCallDelta {
function?: ApiChatCompletionToolCallFunctionDelta & { arguments?: string };
}
export interface ApiChatCompletionStreamChunk {
@ -195,6 +212,7 @@ export interface ApiChatCompletionStreamChunk {
content?: string;
reasoning_content?: string;
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
};
}>;
timings?: {
@ -216,6 +234,7 @@ export interface ApiChatCompletionResponse {
content: string;
reasoning_content?: string;
model?: string;
tool_calls?: ApiChatCompletionToolCallDelta[];
};
}>;
}

View file

@ -60,6 +60,7 @@ export interface DatabaseMessage {
content: string;
parent: string;
thinking: string;
toolCalls?: string;
children: string[];
extra?: DatabaseMessageExtra[];
timings?: ChatMessageTimings;

View file

@ -38,12 +38,19 @@ export interface SettingsChatServiceOptions {
samplers?: string | string[];
// Custom parameters
custom?: string;
timings_per_token?: boolean;
// Callbacks
onChunk?: (chunk: string) => void;
onReasoningChunk?: (chunk: string) => void;
onToolCallChunk?: (chunk: string) => void;
onModel?: (model: string) => void;
onFirstValidChunk?: () => void;
onComplete?: (response: string, reasoningContent?: string, timings?: ChatMessageTimings) => void;
onComplete?: (
response: string,
reasoningContent?: string,
timings?: ChatMessageTimings,
toolCalls?: string
) => void;
onError?: (error: Error) => void;
}