From 45363632cbd593537d541e81b600242e0b3d47fc Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 22 Aug 2025 11:28:03 -0700 Subject: [PATCH 1/7] ggml WebGPU: add support for quantization types (#15440) * Begin work on set_rows * Work on set rows * Add error buffers for reporting unsupported SET_ROWS indices * Remove extra comments * Work on templating for different types in shaders * Work on shader type generation * Working q4_0 mul_mat and some templating for different types * Add q4_0_f16 matmul and fix device init * Add matmul support for basic quantization types * Add q2_k and q3_k quantization * Add rest of k-quants * Get firt i-quant working * Closer to supporting all i-quants * Support rest of i-quants * Cleanup code * Fix python formatting * debug * Bugfix for memset * Add padding to end of buffers on creation * Simplify bit-shifting * Update usage of StringView --- ggml/src/ggml-webgpu/CMakeLists.txt | 4 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 426 ++-- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 90 +- ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl | 16 +- .../wgsl-shaders/mul_mat.tmpl.wgsl | 1794 +++++++++++++++++ .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 56 - 6 files changed, 2143 insertions(+), 243 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 79ef68b85..78a985a4d 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -20,8 +20,8 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR} COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8 ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py - --input "${SHADER_DIR}" - --output "${SHADER_HEADER}" + --input_dir "${SHADER_DIR}" + --output_file "${SHADER_HEADER}" DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py VERBATIM ) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ba1addc8d..32f1e304e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -118,13 +118,11 @@ struct webgpu_context_struct { std::recursive_mutex mutex; - bool device_init = false; - webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; wgpu::ComputePipeline memset_pipeline; - wgpu::ComputePipeline mul_mat_pipeline; + wgpu::ComputePipeline mul_mat_pipeline[30][2]; wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline cpy_pipeline; @@ -238,7 +236,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } }), UINT64_MAX); @@ -278,7 +276,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { wgpu::CallbackMode::AllowSpontaneous, [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } // Free the staged buffers ctx->param_buf_pool.free_bufs(staged_param_bufs); @@ -294,7 +292,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { wgpu::CallbackMode::AllowSpontaneous, [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data); + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); } else { const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); if (*error_data) { @@ -331,6 +329,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. static void ggml_backend_webgpu_debug(webgpu_context & ctx) { + ggml_backend_webgpu_submit_queue(ctx); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -421,15 +420,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); } -static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - /** End WebGPU Actions */ /** GGML Backend Interface */ @@ -447,19 +437,36 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { GGML_UNUSED(ctx); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); +} + static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - size_t src_offset = ggml_backend_webgpu_tensor_offset(src); - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - uint32_t ne = (uint32_t) ggml_nelements(dst); + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { ne, - (uint32_t) (src_misalignment / ggml_type_size(src->type)), - (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // Convert byte-strides to element-strides (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), @@ -477,15 +484,13 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src), - .offset = src_offset, - .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = dst_offset, - .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; @@ -504,21 +509,9 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t error_bufs.host_buf.Unmap(); } - size_t src_offset = ggml_backend_webgpu_tensor_offset(src); - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx); - size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - - std::vector params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)), - (uint32_t) (idx_misalignment / ggml_type_size(idx->type)), - (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // Convert byte-strides to element-strides (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), @@ -540,18 +533,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src), - .offset = ggml_backend_webgpu_tensor_offset(src), - .size = ggml_nbytes(src) }, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(idx), - .offset = ggml_backend_webgpu_tensor_offset(idx), - .size = ggml_nbytes(idx) }, + .buffer = ggml_webgpu_tensor_buf(idx), + .offset = ggml_webgpu_tensor_align_offset(ctx, idx), + .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, { .binding = 2, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = ggml_backend_webgpu_tensor_offset(dst), - .size = ggml_nbytes(dst) }, - { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; @@ -565,15 +558,18 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) dst->ne[1], // number of rows in result (M) (uint32_t) dst->ne[0], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3 + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 (uint32_t) src0->ne[2], // batch size in dimension 2 (uint32_t) src0->ne[3], // batch size in dimension 3 (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 @@ -582,22 +578,22 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t std::vector entries = { { .binding = 0, - .buffer = ggml_backend_webgpu_tensor_buf(src0), - .offset = ggml_backend_webgpu_tensor_offset(src0), - .size = ggml_nbytes(src0) }, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, { .binding = 1, - .buffer = ggml_backend_webgpu_tensor_buf(src1), - .offset = ggml_backend_webgpu_tensor_offset(src1), - .size = ggml_nbytes(src1) }, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, { .binding = 2, - .buffer = ggml_backend_webgpu_tensor_buf(dst), - .offset = ggml_backend_webgpu_tensor_offset(dst), - .size = ggml_nbytes(dst) } + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } // Returns true if node has enqueued work into the queue, false otherwise @@ -827,7 +823,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b wgpu::Buffer buf; ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, - size, + (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); @@ -907,7 +903,94 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_mul_mat_f32_f32, + "mul_mat_f32_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_mul_mat_f16_f16, + "mul_mat_f16_f16"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_mul_mat_f16_f32, + "mul_mat_f16_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], + wgsl_mul_mat_q4_0_f32, + "mul_mat_q4_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], + wgsl_mul_mat_q4_1_f32, + "mul_mat_q4_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32], + wgsl_mul_mat_q5_0_f32, + "mul_mat_q5_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32], + wgsl_mul_mat_q5_1_f32, + "mul_mat_q5_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32], + wgsl_mul_mat_q8_0_f32, + "mul_mat_q8_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32], + wgsl_mul_mat_q2_k_f32, + "mul_mat_q2_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32], + wgsl_mul_mat_q3_k_f32, + "mul_mat_q3_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32], + wgsl_mul_mat_q4_k_f32, + "mul_mat_q4_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32], + wgsl_mul_mat_q5_k_f32, + "mul_mat_q5_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32], + wgsl_mul_mat_q6_k_f32, + "mul_mat_q6_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32], + wgsl_mul_mat_iq2_xxs_f32, + "mul_mat_iq2_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32], + wgsl_mul_mat_iq2_xs_f32, + "mul_mat_iq2_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32], + wgsl_mul_mat_iq2_s_f32, + "mul_mat_iq2_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32], + wgsl_mul_mat_iq3_xxs_f32, + "mul_mat_iq3_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32], + wgsl_mul_mat_iq3_s_f32, + "mul_mat_iq3_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32], + wgsl_mul_mat_iq1_s_f32, + "mul_mat_iq1_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32], + wgsl_mul_mat_iq1_m_f32, + "mul_mat_iq1_m_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32], + wgsl_mul_mat_iq4_nl_f32, + "mul_mat_iq4_nl_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], + wgsl_mul_mat_iq4_xs_f32, + "mul_mat_iq4_xs_f32"); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -933,79 +1016,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; - // Multiple threads may try to initialize the device - std::lock_guard lock(webgpu_ctx->mutex); - if (!webgpu_ctx->device_init) { - // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &webgpu_ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_LOG_ERROR( - "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_LOG_ERROR( - "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - webgpu_ctx->instance.WaitAny( - webgpu_ctx->adapter.RequestDevice( - &dev_desc, - wgpu::CallbackMode::AllowSpontaneous, - [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); - return; - } - webgpu_ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(webgpu_ctx->device != nullptr); - - // Initialize (compute) queue - webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - webgpu_ctx->param_buf_pool.init(webgpu_ctx->device, - WEBGPU_NUM_PARAM_BUFS, - WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device, - WEBGPU_NUM_SET_ROWS_ERROR_BUFS, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(webgpu_ctx); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_set_rows_pipeline(webgpu_ctx); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(webgpu_ctx->device, - webgpu_ctx->debug_host_buf, - WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, - "debug_host_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, - webgpu_ctx->debug_dev_buf, - WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, - "debug_dev_buf"); -#endif - webgpu_ctx->device_init = true; - } - static ggml_backend_webgpu_context backend_ctx; backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; backend_ctx.webgpu_ctx = webgpu_ctx; @@ -1053,10 +1063,45 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_VIEW: case GGML_OP_PERMUTE: return true; - case GGML_OP_CPY | GGML_OP_SET_ROWS: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_MUL_MAT: - return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + { + switch (op->src[1]->type) { + case GGML_TYPE_F16: + return op->src[0]->type == GGML_TYPE_F16; + case GGML_TYPE_F32: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return true; + default: + return false; + } + default: + return false; + } + } default: return false; } @@ -1123,20 +1168,87 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::AdapterInfo info{}; ctx->adapter.GetInfo(&info); + // Initialize device + std::vector required_features = { wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::ImplicitDeviceSynchronization }; + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR( + "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR( + "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); + }); + ctx->instance.WaitAny(ctx->adapter.RequestDevice( + &dev_desc, + wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); + return; + } + ctx->device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->device != nullptr); + + // Initialize (compute) queue + ctx->queue = ctx->device.GetQueue(); + + // Create buffer pool for shader parameters + ctx->param_buf_pool.init(ctx->device, + WEBGPU_NUM_PARAM_BUFS, + WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ctx->set_rows_error_buf_pool.init(ctx->device, + WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + + ggml_webgpu_init_memset_pipeline(ctx); + ggml_webgpu_init_mul_mat_pipeline(ctx); + ggml_webgpu_init_set_rows_pipeline(ctx); + ggml_webgpu_init_cpy_pipeline(ctx); + +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(ctx->device, + ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + "debug_host_buf"); + ggml_webgpu_create_buffer(ctx->device, + ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, + "debug_dev_buf"); +#endif + static ggml_backend_webgpu_device_context device_ctx; device_ctx.webgpu_ctx = ctx; device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = std::string(info.description.data); + device_ctx.device_desc = info.description; GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", info.vendorID, - info.vendor.data, - info.architecture.data, + std::string(info.vendor).c_str(), + std::string(info.architecture).c_str(), info.deviceID, - info.device.data, - info.description.data); + std::string(info.device).c_str(), + std::string(info.description).c_str()); // See GGML Backend Device Interface section static ggml_backend_device device = { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 962dcd6b1..cc8def7f1 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,35 +1,85 @@ import os +import re +import ast import argparse -def escape_triple_quotes(wgsl): - # Simple defense in case of embedded """ - return wgsl.replace('"""', '\\"""') +def extract_block(text, name): + pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' + match = re.search(pattern, text, re.DOTALL) + if not match: + raise ValueError(f"Missing block: {name}") + return match.group(1).strip() -def to_cpp_string_literal(varname, content): - return f'const char* wgsl_{varname} = R"({content})";\n' +def parse_decls(decls_text): + decls = {} + for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): + decls[name.strip()] = code.strip() + return decls + + +def replace_placeholders(shader_text, replacements): + for key, val in replacements.items(): + # Match {{KEY}} literally, where KEY is escaped + pattern = r'{{\s*' + re.escape(key) + r'\s*}}' + shader_text = re.sub(pattern, str(val), shader_text) + return shader_text + + +def write_shader(shader_name, shader_code, output_dir, outfile): + if output_dir: + wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") + with open(wgsl_filename, "w", encoding="utf-8") as f_out: + f_out.write(shader_code) + outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') + + +def generate_variants(shader_path, output_dir, outfile): + shader_base_name = shader_path.split("/")[-1].split(".")[0] + + with open(shader_path, "r", encoding="utf-8") as f: + text = f.read() + + try: + variants = ast.literal_eval(extract_block(text, "VARIANTS")) + except ValueError: + write_shader(shader_base_name, text, output_dir, outfile) + else: + decls_map = parse_decls(extract_block(text, "DECLS")) + shader_template = extract_block(text, "SHADER") + + for variant in variants: + decls = variant["DECLS"] + decls_code = "" + for key in decls: + if key not in decls_map: + raise ValueError(f"DECLS key '{key}' not found.") + decls_code += decls_map[key] + "\n\n" + + shader_variant = replace_placeholders(shader_template, variant["REPLS"]) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) + write_shader(output_name, final_shader, output_dir, outfile) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True) - parser.add_argument('--output', required=True) + parser.add_argument("--input_dir", required=True) + parser.add_argument("--output_file", required=True) + parser.add_argument("--output_dir") args = parser.parse_args() - with open(args.output, 'w', encoding='utf-8') as out: - out.write("// Auto-generated shader embedding \n\n") - for fname in sorted(os.listdir(args.input)): - if not fname.endswith('.wgsl'): - continue - shader_path = os.path.join(args.input, fname) - varname = os.path.splitext(fname)[0] - with open(shader_path, 'r', encoding='utf-8') as f: - content = f.read() - content = escape_triple_quotes(content) - out.write(to_cpp_string_literal(varname, content)) - out.write('\n') + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.output_file, "w", encoding="utf-8") as out: + out.write("// Auto-generated shader embedding\n\n") + for fname in sorted(os.listdir(args.input_dir)): + if fname.endswith(".wgsl"): + generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl index cb7c8c3e0..194d2d6f5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl @@ -19,20 +19,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let start = params.offset; let end = params.offset + params.size; - for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) { + for (var j: u32 = 0u; j < bytes_per_thread; j += 4) { let byte_index = start + i + j; - if (byte_index + 4u <= end) { - output_buffer[(byte_index >> 2u)] = params.value; + if (byte_index + 4 <= end) { + output_buffer[byte_index >> 2] = params.value; } else { // Handle tail (unaligned) - for (var k: u32 = 0u; k < 4u; k = k + 1u) { + for (var k: u32 = 0; k < 4; k++) { let idx = byte_index + k; if (idx < end) { - let word_idx = idx >> 2u; - let byte_offset = (idx & 3u) * 8u; - let mask = ~(0xffu << byte_offset); + let word_idx = idx >> 2; + let bit_offset = (idx & 3) * 8u; + let mask = ~(0xffu << bit_offset); let existing = output_buffer[word_idx]; - output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset); + output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset)); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl new file mode 100644 index 000000000..79465c298 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -0,0 +1,1794 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "BLOCK_SIZE" : 1 + }, + "DECLS" : ["FLOAT"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_1", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q4_1"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_1", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q5_1"] + }, + { + "REPLS": { + "SRC0_TYPE": "q8_0", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32 + }, + "DECLS": ["BYTE_HELPERS", "Q8_0"] + }, + { + "REPLS": { + "SRC0_TYPE": "q2_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q2_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q3_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q3_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q4_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q5_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "q6_k", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "Q6_K"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_xxs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_xs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq2_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq3_xxs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq3_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq1_s", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_TABLE","IQ1_S"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq1_m", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256 + }, + "DECLS": ["BYTE_HELPERS", "IQ1_TABLE","IQ1_M"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq4_nl", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 32, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_TABLE", "IQ4_NL"] + }, + { + "REPLS": { + "SRC0_TYPE": "iq4_xs", + "SRC1_TYPE": "f32", + "BLOCK_SIZE": 256, + }, + "DECLS": ["BYTE_HELPERS", "IQ4_TABLE", "IQ4_XS"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(BYTE_HELPERS) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +#enddecl(BYTE_HELPERS) + +#decl(FLOAT) +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); +} +#enddecl(FLOAT) + +#decl(Q4_0) +struct q4_0 { + d: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_0 = src0[src0_idx_base + offset]; + let d = f32(block_q4_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q4_0) + +#decl(Q4_1) +struct q4_1 { + d: f16, + m: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q4_1 = src0[src0_idx_base + offset]; + let d = f32(block_q4_1.d); + let m = f32(block_q4_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q4_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q4_1) + +#decl(Q5_0) +struct q5_0 { + d: f16, + qh: array, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_0 = src0[src0_idx_base + offset]; + let d = f32(block_q5_0.d); + var sum: f32 = 0.0; + let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + for (var j: u32 = 0; j < 4; j++) { + let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q5_0) + +#decl(Q5_1) +struct q5_1 { + d: f16, + m: f16, + qh: u32, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q5_1 = src0[src0_idx_base + offset]; + let d = f32(block_q5_1.d); + let m = f32(block_q5_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 4; j++) { + let q_packed = block_q5_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_lo * f32(src1[src1_offset]); + sum += q_hi * f32(src1[src1_offset + 16]); + } + } + return sum; +} +#enddecl(Q5_1) + +#decl(Q8_0) +struct q8_0 { + d: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_0 = src0[src0_idx_base + offset]; + let d = f32(block_q8_0.d); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#enddecl(Q8_0) + +#decl(Q8_1) +struct q8_1 { + d: f16, + m: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block_q8_1 = src0[src0_idx_base + offset]; + let d = f32(block_q8_1.d); + let m = f32(block_q8_1.m); + var sum: f32 = 0.0; + for (var j: u32 = 0; j < 8; j++) { + let q_packed = block_q8_1.qs[j]; + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d + m; + let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; + sum += q_val * f32(src1[src1_offset]); + } + } + return sum; +} +#enddecl(Q8_1) + +#decl(Q2_K) +// 16 blocks of 16 elements each +struct q2_k { + scales: array, + qs: array, + d: f16, + dmin: f16 +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(block.scales[is / 4], is % 4); + is++; + let dl = d * f32(sc & 0xF); + let ml = m * f32(sc >> 4); + for (var l: u32 = 0u; l < 16; l++) { + let q_idx = q_b_idx + k + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 3; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + } + return sum; +} + +#enddecl(Q2_K) + +#decl(Q3_K) +// 16 blocks of 16 elements each +struct q3_k { + hmask: array, + qs: array, + scales: array, // 6-bit quantized values + d: f16 +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, + // and 2-bits from the last 4 bytes + let kmask1: u32 = 0x03030303; + let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + // convert arrays of f16 -> u32 + var hmask_vals: array; + for (var i: u32 = 0; i < 8; i++) { + hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + } + var qs_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var m: u32 = 1; + // 2 halves of the block (128 elements each) + for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { + // 4 groups (each group has 2 blocks of 16 elements) + for (var shift: u32 = 0; shift < 8; shift += 2) { + // 2 blocks + for (var k: u32 = 0; k < 32; k += 16) { + let sc = get_byte(scale_vals[is / 4], is % 4); + is++; + let dl = d * (f32(sc) - 32.0); + for (var l: u32 = 0u; l < 16u; l++) { + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3; + sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; + src1_i++; + } + } + m <<= 1; + } + } + return sum; +} + +#enddecl(Q3_K) + +#decl(Q45_K_SCALE_MIN) + +fn get_scale_min(is: u32, scales: array) -> vec2 { + if (is < 4) { + let sc_byte = get_byte(scales[is / 4], is % 4); + let min_byte = get_byte(scales[(is + 4) / 4], is % 4); + return vec2(f32(sc_byte & 63), f32(min_byte & 63)); + } else { + let sc_min_lo = get_byte(scales[(is + 4) / 4], (is + 4) % 4); + let sc_hi = get_byte(scales[(is - 4) / 4], (is - 4) % 4); + let min_hi = get_byte(scales[is / 4], is % 4); + let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4); + let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4); + return vec2(f32(sc), f32(m)); + } +} + +#enddecl(Q45_K_SCALE_MIN) + +#decl(Q4_K) +// 8 blocks of 32 elements each +struct q4_k { + d: f16, + dmin: f16, + scales: array, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qs_val = (q_byte >> shift) & 0xF; + sum += (f32(qs_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(Q4_K) + +#decl(Q5_K) +// 8 blocks of 32 elements each +struct q5_k { + d: f16, + dmin: f16, + scales: array, + qh: array, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let m = f32(block.dmin); + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var is: u32 = 0; + var u: u32 = 1; + // 2 blocks each iteration + for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { + for (var shift: u32 = 0; shift < 8; shift += 4) { + let scale_min = get_scale_min(is, block.scales); + is++; + let dl = d * scale_min.x; + let ml = m * scale_min.y; + for (var l: u32 = 0; l < 32; l++) { + let q_idx = q_b_idx + l; + let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); + let qh_byte = get_byte(block.qh[l / 4], l % 4); + let qs_val = (q_byte >> shift) & 0xF; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; + src1_i++; + } + u <<= 1; + } + } + return sum; +} + +#enddecl(Q5_K) + +#decl(Q6_K) +// 16 blocks of 16 elements each +struct q6_k { + ql: array, + qh: array, + scales: array, + d: f16 +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + + // convert arrays of f16 -> u32 + var ql_vals: array; + for (var i: u32 = 0; i < 32; i++) { + ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + } + var qh_vals: array; + for (var i: u32 = 0; i < 16; i++) { + qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + } + var scale_vals: array; + for (var i: u32 = 0; i < 4; i++) { + scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + } + + var sum = 0.0; + var src1_i = src1_idx_base + offset * 256; + var qh_b_idx: u32 = 0; + var sc_b_idx: u32 = 0; + for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { + for (var l: u32 = 0; l < 32; l++) { + let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); + let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); + let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); + + let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; + let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; + let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; + let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; + + let is = l/16; + let is1 = sc_b_idx + is; + let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); + let is2 = sc_b_idx + is + 2; + let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); + let is3 = sc_b_idx + is + 4; + let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); + let is4 = sc_b_idx + is + 6; + let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); + + sum += d * f32(sc1) * q1 * src1[src1_i + l]; + sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; + sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; + sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; + } + src1_i += 128; + qh_b_idx += 32; + sc_b_idx += 8; + } + return sum; +} + +#enddecl(Q6_K) + +#decl(IQ23_TABLES) +const kmask_iq2xs : array = array( + 0x08040201u, // 1, 2, 4, 8 + 0x80402010u // 16, 32, 64, 128 +); + +const ksigns_iq2xs: array = array( + 0x03828100,0x87060584,0x8b0a0988,0x0f8e8d0c, + 0x93121190,0x17969514,0x1b9a9918,0x9f1e1d9c, + 0xa32221a0,0x27a6a524,0x2baaa928,0xaf2e2dac, + 0x33b2b130,0xb73635b4,0xbb3a39b8,0x3fbebd3c, + 0xc34241c0,0x47c6c544,0x4bcac948,0xcf4e4dcc, + 0x53d2d150,0xd75655d4,0xdb5a59d8,0x5fdedd5c, + 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, + 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc +); +#enddecl(IQ23_TABLES) + +#decl(IQ2_XXS) + +const iq2xxs_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, + 0x082b082b, 0x08080808, 0x082b2b08, 0x08080808, 0x082b2b2b, 0x08080808, 0x19080819, 0x08080808, + 0x19081908, 0x08080808, 0x19190808, 0x08080808, 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, + 0x192b1908, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b082b2b, 0x08080808, + 0x2b2b082b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, 0x08190808, 0x08080819, + 0x08191919, 0x08080819, 0x19080808, 0x08080819, 0x2b081908, 0x08080819, 0x2b192b08, 0x08080819, + 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x082b082b, 0x0808082b, 0x2b08082b, 0x0808082b, + 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x08190808, 0x08081908, 0x082b0819, 0x08081908, + 0x082b1908, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19082b08, 0x08081908, + 0x192b0808, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, + 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, 0x08082b08, 0x08081919, + 0x082b0808, 0x08081919, 0x1908192b, 0x08081919, 0x192b2b19, 0x08081919, 0x2b080808, 0x08081919, + 0x2b190819, 0x08081919, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, 0x19080808, 0x0808192b, + 0x2b081908, 0x0808192b, 0x2b2b1908, 0x0808192b, 0x08080808, 0x08082b08, 0x08081919, 0x08082b08, + 0x08082b08, 0x08082b08, 0x08191908, 0x08082b08, 0x082b2b08, 0x08082b08, 0x19080819, 0x08082b08, + 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x2b082b08, 0x08082b08, + 0x08081908, 0x08082b19, 0x19080808, 0x08082b19, 0x0808082b, 0x08082b2b, 0x08191908, 0x08082b2b, + 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x08190808, 0x08190808, 0x082b0819, 0x08190808, + 0x19080808, 0x08190808, 0x192b0808, 0x08190808, 0x2b081908, 0x08190808, 0x2b190808, 0x08190808, + 0x2b191919, 0x08190808, 0x08080808, 0x08190819, 0x08082b08, 0x08190819, 0x082b0808, 0x08190819, + 0x19190808, 0x08190819, 0x19192b2b, 0x08190819, 0x2b080808, 0x08190819, 0x082b1908, 0x0819082b, + 0x19081919, 0x0819082b, 0x08080808, 0x08191908, 0x08082b08, 0x08191908, 0x082b0808, 0x08191908, + 0x082b1919, 0x08191908, 0x19082b19, 0x08191908, 0x2b080808, 0x08191908, 0x08192b08, 0x08191919, + 0x192b082b, 0x08191919, 0x08080808, 0x0819192b, 0x0819192b, 0x0819192b, 0x08080819, 0x08192b08, + 0x08081908, 0x08192b08, 0x08190808, 0x08192b08, 0x19080808, 0x08192b08, 0x2b080819, 0x08192b08, + 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x2b2b0808, 0x08192b19, 0x19190819, 0x08192b2b, + 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08082b2b, 0x082b0808, 0x19081908, 0x082b0808, + 0x192b0819, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b08082b, 0x082b0808, 0x082b2b19, 0x082b0819, + 0x19082b08, 0x082b0819, 0x08080808, 0x082b082b, 0x0808082b, 0x082b082b, 0x08080819, 0x082b1908, + 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x19080808, 0x082b1908, 0x1919192b, 0x082b1908, + 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x192b1908, 0x082b1919, 0x2b190808, 0x082b192b, + 0x08082b08, 0x082b2b08, 0x082b0808, 0x082b2b08, 0x2b191908, 0x082b2b08, 0x19081908, 0x082b2b2b, + 0x08080819, 0x19080808, 0x08081908, 0x19080808, 0x08190808, 0x19080808, 0x08192b08, 0x19080808, + 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x19080808, 0x19080808, 0x19082b08, 0x19080808, + 0x1919192b, 0x19080808, 0x192b0808, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, + 0x2b190808, 0x19080808, 0x08080808, 0x19080819, 0x082b0808, 0x19080819, 0x192b0819, 0x19080819, + 0x2b080808, 0x19080819, 0x2b081919, 0x19080819, 0x08080819, 0x1908082b, 0x08190808, 0x1908082b, + 0x19082b08, 0x1908082b, 0x1919192b, 0x1908082b, 0x192b2b08, 0x1908082b, 0x08080808, 0x19081908, + 0x08082b08, 0x19081908, 0x082b0808, 0x19081908, 0x2b080808, 0x19081908, 0x2b192b19, 0x19081908, + 0x0819082b, 0x19081919, 0x082b1908, 0x19081919, 0x08080808, 0x1908192b, 0x08080819, 0x19082b08, + 0x08081908, 0x19082b08, 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, + 0x08080808, 0x19082b19, 0x19192b08, 0x19082b19, 0x192b0819, 0x19082b19, 0x2b08082b, 0x19082b19, + 0x19081919, 0x19082b2b, 0x2b190808, 0x19082b2b, 0x08080808, 0x19190808, 0x08082b08, 0x19190808, + 0x08190819, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x2b080808, 0x19190808, + 0x2b082b08, 0x19190808, 0x08081908, 0x19190819, 0x1908082b, 0x19190819, 0x2b2b1908, 0x19190819, + 0x2b190819, 0x1919082b, 0x2b190808, 0x19191908, 0x2b19082b, 0x19191908, 0x08082b2b, 0x19191919, + 0x08080819, 0x1919192b, 0x19191908, 0x1919192b, 0x08080808, 0x19192b08, 0x08190819, 0x19192b08, + 0x08192b19, 0x19192b08, 0x192b1908, 0x19192b08, 0x19080808, 0x19192b19, 0x08082b08, 0x19192b2b, + 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, 0x192b2b08, 0x192b0808, + 0x08080808, 0x192b0819, 0x19191919, 0x192b0819, 0x08192b08, 0x192b082b, 0x192b0808, 0x192b082b, + 0x08080808, 0x192b1908, 0x08081919, 0x192b1908, 0x08190808, 0x192b1919, 0x0819082b, 0x192b1919, + 0x2b081908, 0x192b1919, 0x1908082b, 0x192b2b08, 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, + 0x08082b2b, 0x2b080808, 0x19080819, 0x2b080808, 0x2b08082b, 0x2b080808, 0x08081908, 0x2b080819, + 0x08192b08, 0x2b080819, 0x19080808, 0x2b080819, 0x08190819, 0x2b08082b, 0x08080819, 0x2b081908, + 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x08191919, 0x2b081908, 0x19080808, 0x2b081908, + 0x192b0808, 0x2b081908, 0x08080808, 0x2b081919, 0x1908192b, 0x2b081919, 0x2b191908, 0x2b081919, + 0x08082b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x192b0808, 0x2b08192b, 0x0808082b, 0x2b082b08, + 0x08081908, 0x2b082b19, 0x08190819, 0x2b082b2b, 0x08081908, 0x2b190808, 0x08190808, 0x2b190808, + 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, 0x2b2b0819, 0x2b190808, 0x0819192b, 0x2b190819, + 0x2b080808, 0x2b190819, 0x19081919, 0x2b19082b, 0x08080808, 0x2b191908, 0x082b082b, 0x2b191908, + 0x19081908, 0x2b191908, 0x19190819, 0x2b191919, 0x2b080819, 0x2b192b08, 0x082b0808, 0x2b192b19, + 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, + 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 +); + +struct iq2_xxs { + d: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); + let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; + for (var l: u32 = 0; l < 4; l++) { + let ig = get_byte(aux0, l) * 8; + let is = (aux1 >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += db * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ2_XXS) + +#decl(IQ2_XS) +const iq2xs_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, + 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808, + 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808, + 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808, + 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x2b080808, 0x08080808, + 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, + 0x2b191908, 0x08080808, 0x2b192b19, 0x08080808, 0x2b2b0808, 0x08080808, 0x08080819, 0x08080819, + 0x08081908, 0x08080819, 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, + 0x0819082b, 0x08080819, 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x08192b2b, 0x08080819, + 0x082b0819, 0x08080819, 0x082b1908, 0x08080819, 0x19080808, 0x08080819, 0x1908082b, 0x08080819, + 0x19081919, 0x08080819, 0x19082b08, 0x08080819, 0x19190819, 0x08080819, 0x19191908, 0x08080819, + 0x192b0808, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, 0x2b081908, 0x08080819, + 0x2b190808, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x08081919, 0x0808082b, + 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, 0x082b0808, 0x0808082b, + 0x19080819, 0x0808082b, 0x19081908, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b, + 0x2b080808, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, + 0x0808192b, 0x08081908, 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, + 0x08191919, 0x08081908, 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, + 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, 0x19082b08, 0x08081908, + 0x19190819, 0x08081908, 0x19191908, 0x08081908, 0x1919192b, 0x08081908, 0x192b0808, 0x08081908, + 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, 0x08080808, 0x08081919, + 0x0808082b, 0x08081919, 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08190819, 0x08081919, + 0x08191908, 0x08081919, 0x082b0808, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919, + 0x19190808, 0x08081919, 0x192b0819, 0x08081919, 0x2b080808, 0x08081919, 0x08080819, 0x0808192b, + 0x08081908, 0x0808192b, 0x08190808, 0x0808192b, 0x082b192b, 0x0808192b, 0x19080808, 0x0808192b, + 0x1908082b, 0x0808192b, 0x2b081908, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08, + 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08082b2b, 0x08082b08, 0x08190819, 0x08082b08, + 0x08191908, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, 0x19080819, 0x08082b08, + 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x19192b08, 0x08082b08, 0x2b080808, 0x08082b08, + 0x2b2b0808, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, 0x08081908, 0x08082b19, + 0x08190808, 0x08082b19, 0x19080808, 0x08082b19, 0x2b080819, 0x08082b19, 0x2b082b19, 0x08082b19, + 0x08080808, 0x08082b2b, 0x082b0808, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x2b19192b, 0x08082b2b, + 0x2b2b0808, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x0808192b, 0x08190808, + 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, 0x08191919, 0x08190808, + 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, 0x19080808, 0x08190808, + 0x1908082b, 0x08190808, 0x19081919, 0x08190808, 0x19082b08, 0x08190808, 0x19190819, 0x08190808, + 0x19191908, 0x08190808, 0x192b0808, 0x08190808, 0x192b2b2b, 0x08190808, 0x2b080819, 0x08190808, + 0x2b081908, 0x08190808, 0x2b190808, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, + 0x08081919, 0x08190819, 0x08082b08, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819, + 0x082b0808, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, 0x19190808, 0x08190819, + 0x2b080808, 0x08190819, 0x2b191908, 0x08190819, 0x2b19192b, 0x08190819, 0x08080819, 0x0819082b, + 0x08081908, 0x0819082b, 0x0808192b, 0x0819082b, 0x08190808, 0x0819082b, 0x19080808, 0x0819082b, + 0x192b0808, 0x0819082b, 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, + 0x08082b08, 0x08191908, 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x082b0808, 0x08191908, + 0x19080819, 0x08191908, 0x19081908, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908, + 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x08080819, 0x08191919, 0x08081908, 0x08191919, + 0x08190808, 0x08191919, 0x19080808, 0x08191919, 0x08080808, 0x0819192b, 0x08191908, 0x0819192b, + 0x19082b19, 0x0819192b, 0x08080819, 0x08192b08, 0x08081908, 0x08192b08, 0x08190808, 0x08192b08, + 0x0819082b, 0x08192b08, 0x19080808, 0x08192b08, 0x19191908, 0x08192b08, 0x2b08192b, 0x08192b08, + 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x192b192b, 0x08192b19, 0x19190819, 0x08192b2b, + 0x2b2b2b19, 0x08192b2b, 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, + 0x08082b08, 0x082b0808, 0x08082b2b, 0x082b0808, 0x08190819, 0x082b0808, 0x08191908, 0x082b0808, + 0x082b0808, 0x082b0808, 0x19080819, 0x082b0808, 0x19081908, 0x082b0808, 0x19190808, 0x082b0808, + 0x2b080808, 0x082b0808, 0x2b2b0808, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, + 0x08190808, 0x082b0819, 0x19080808, 0x082b0819, 0x19082b08, 0x082b0819, 0x192b1919, 0x082b0819, + 0x08080808, 0x082b082b, 0x082b082b, 0x082b082b, 0x2b080808, 0x082b082b, 0x2b2b2b08, 0x082b082b, + 0x08080819, 0x082b1908, 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x082b2b19, 0x082b1908, + 0x19080808, 0x082b1908, 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x1919082b, 0x082b1919, + 0x2b192b19, 0x082b1919, 0x08080819, 0x082b192b, 0x08192b2b, 0x082b192b, 0x2b2b192b, 0x082b192b, + 0x08080808, 0x082b2b08, 0x08082b08, 0x082b2b08, 0x08082b2b, 0x082b2b08, 0x082b0808, 0x082b2b08, + 0x19191919, 0x082b2b08, 0x2b082b08, 0x082b2b08, 0x2b2b082b, 0x082b2b08, 0x192b2b08, 0x082b2b19, + 0x2b190808, 0x082b2b19, 0x08082b08, 0x082b2b2b, 0x082b0808, 0x082b2b2b, 0x2b08082b, 0x082b2b2b, + 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, 0x08081908, 0x19080808, + 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, 0x0819082b, 0x19080808, + 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, + 0x19080808, 0x19080808, 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, + 0x19082b2b, 0x19080808, 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x192b0808, 0x19080808, + 0x192b1919, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, 0x2b190808, 0x19080808, + 0x08080808, 0x19080819, 0x0808082b, 0x19080819, 0x08081919, 0x19080819, 0x08082b08, 0x19080819, + 0x08190819, 0x19080819, 0x08191908, 0x19080819, 0x082b0808, 0x19080819, 0x19080819, 0x19080819, + 0x19081908, 0x19080819, 0x19190808, 0x19080819, 0x2b080808, 0x19080819, 0x2b081919, 0x19080819, + 0x2b2b082b, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, 0x08190808, 0x1908082b, + 0x0819082b, 0x1908082b, 0x082b2b19, 0x1908082b, 0x19080808, 0x1908082b, 0x08080808, 0x19081908, + 0x0808082b, 0x19081908, 0x08081919, 0x19081908, 0x08082b08, 0x19081908, 0x08190819, 0x19081908, + 0x08191908, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x19080819, 0x19081908, + 0x19081908, 0x19081908, 0x19190808, 0x19081908, 0x2b080808, 0x19081908, 0x2b191908, 0x19081908, + 0x08080819, 0x19081919, 0x08081908, 0x19081919, 0x08190808, 0x19081919, 0x082b1908, 0x19081919, + 0x19080808, 0x19081919, 0x2b192b2b, 0x19081919, 0x08080808, 0x1908192b, 0x08082b2b, 0x1908192b, + 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x08080819, 0x19082b08, 0x08081908, 0x19082b08, + 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, 0x19191908, 0x19082b08, + 0x192b082b, 0x19082b08, 0x08080808, 0x19082b19, 0x08190819, 0x19082b19, 0x19081908, 0x19082b19, + 0x19190808, 0x19082b19, 0x192b2b19, 0x19082b19, 0x08081908, 0x19082b2b, 0x08080808, 0x19190808, + 0x0808082b, 0x19190808, 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, + 0x08191908, 0x19190808, 0x082b0808, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, + 0x19081908, 0x19190808, 0x19190808, 0x19190808, 0x2b080808, 0x19190808, 0x08080819, 0x19190819, + 0x08081908, 0x19190819, 0x08190808, 0x19190819, 0x08191919, 0x19190819, 0x19080808, 0x19190819, + 0x1908082b, 0x19190819, 0x08080808, 0x1919082b, 0x19081908, 0x1919082b, 0x2b2b2b2b, 0x1919082b, + 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x08190808, 0x19191908, 0x082b0819, 0x19191908, + 0x19080808, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b2b0819, 0x19191908, + 0x08080808, 0x19191919, 0x08082b08, 0x19191919, 0x2b080808, 0x19191919, 0x2b082b08, 0x19191919, + 0x082b0819, 0x1919192b, 0x192b2b08, 0x1919192b, 0x2b2b0819, 0x1919192b, 0x08080808, 0x19192b08, + 0x08191908, 0x19192b08, 0x19080819, 0x19192b08, 0x19190808, 0x19192b08, 0x2b192b19, 0x19192b08, + 0x08192b2b, 0x19192b19, 0x19080808, 0x19192b19, 0x1908082b, 0x19192b19, 0x2b081919, 0x19192b2b, + 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, + 0x19191908, 0x192b0808, 0x192b082b, 0x192b0808, 0x2b08192b, 0x192b0808, 0x2b2b2b19, 0x192b0808, + 0x08080808, 0x192b0819, 0x082b1908, 0x192b082b, 0x19082b2b, 0x192b082b, 0x2b19082b, 0x192b082b, + 0x08080808, 0x192b1908, 0x0819192b, 0x192b1908, 0x08190808, 0x192b1919, 0x19080808, 0x192b1919, + 0x19081919, 0x192b1919, 0x2b2b1908, 0x192b1919, 0x08080819, 0x192b2b08, 0x192b2b2b, 0x192b2b08, + 0x082b1919, 0x192b2b19, 0x0808192b, 0x192b2b2b, 0x19191908, 0x192b2b2b, 0x192b082b, 0x192b2b2b, + 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, + 0x08190819, 0x2b080808, 0x08191908, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b2b2b, 0x2b080808, + 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x2b080808, 0x2b080808, + 0x2b08082b, 0x2b080808, 0x2b2b2b08, 0x2b080808, 0x2b2b2b2b, 0x2b080808, 0x08080819, 0x2b080819, + 0x08081908, 0x2b080819, 0x0808192b, 0x2b080819, 0x08190808, 0x2b080819, 0x19080808, 0x2b080819, + 0x19190819, 0x2b080819, 0x19192b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x082b0808, 0x2b08082b, + 0x2b080808, 0x2b08082b, 0x2b08082b, 0x2b08082b, 0x2b2b0808, 0x2b08082b, 0x2b2b2b08, 0x2b08082b, + 0x08080819, 0x2b081908, 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908, + 0x08191919, 0x2b081908, 0x19080808, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b082b19, 0x2b081908, + 0x08080808, 0x2b081919, 0x19081908, 0x2b081919, 0x2b2b1919, 0x2b081919, 0x08192b08, 0x2b08192b, + 0x192b2b2b, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08082b08, 0x2b082b08, 0x082b1919, 0x2b082b08, + 0x19192b2b, 0x2b082b08, 0x2b080808, 0x2b082b08, 0x2b08082b, 0x2b082b08, 0x2b2b2b08, 0x2b082b08, + 0x0808192b, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x2b080808, 0x2b082b2b, 0x2b082b08, 0x2b082b2b, + 0x2b19192b, 0x2b082b2b, 0x2b2b2b08, 0x2b082b2b, 0x08080819, 0x2b190808, 0x08081908, 0x2b190808, + 0x08190808, 0x2b190808, 0x19080808, 0x2b190808, 0x1919192b, 0x2b190808, 0x2b081908, 0x2b190808, + 0x08080808, 0x2b190819, 0x082b082b, 0x2b190819, 0x192b1908, 0x2b190819, 0x1919192b, 0x2b19082b, + 0x2b082b19, 0x2b19082b, 0x08080808, 0x2b191908, 0x08081919, 0x2b191908, 0x19081908, 0x2b191908, + 0x19190808, 0x2b191908, 0x19192b08, 0x2b191908, 0x082b2b19, 0x2b191919, 0x2b190808, 0x2b191919, + 0x2b19082b, 0x2b191919, 0x19080819, 0x2b19192b, 0x19190819, 0x2b192b08, 0x2b2b192b, 0x2b192b08, + 0x19082b19, 0x2b192b19, 0x08191919, 0x2b192b2b, 0x192b0808, 0x2b192b2b, 0x08080808, 0x2b2b0808, + 0x0808082b, 0x2b2b0808, 0x08082b08, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, 0x082b0808, 0x2b2b0808, + 0x082b2b2b, 0x2b2b0808, 0x2b2b0808, 0x2b2b0808, 0x19190819, 0x2b2b0819, 0x19192b19, 0x2b2b0819, + 0x2b2b192b, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x0808082b, 0x2b2b082b, 0x08082b08, 0x2b2b082b, + 0x082b2b2b, 0x2b2b082b, 0x2b080808, 0x2b2b082b, 0x2b2b0808, 0x2b2b082b, 0x19080808, 0x2b2b1908, + 0x2b191919, 0x2b2b1908, 0x192b1919, 0x2b2b192b, 0x2b192b08, 0x2b2b192b, 0x08082b2b, 0x2b2b2b08, + 0x082b0808, 0x2b2b2b08, 0x082b082b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b0808, 0x2b2b2b08, + 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, + 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b +); + +struct iq2_xs { + d: f16, + qs: array, + scales: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 32; ib += 4) { + let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + for (var l: u32 = 0; l < 4; l++) { + let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let ig = (qs_val & 511) * 8; + let is = qs_val >> 9; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ2_XS) + +#decl(IQ2_S) + +const iq2s_grid = array( + 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, + 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, + 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808, + 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808, + 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808, + 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x192b192b, 0x08080808, + 0x192b2b19, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, + 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, 0x2b191908, 0x08080808, 0x2b2b0808, 0x08080808, + 0x2b2b1919, 0x08080808, 0x2b2b2b2b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, + 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, 0x0819082b, 0x08080819, + 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x082b0819, 0x08080819, 0x082b1908, 0x08080819, + 0x19080808, 0x08080819, 0x1908082b, 0x08080819, 0x19081919, 0x08080819, 0x19082b08, 0x08080819, + 0x19190819, 0x08080819, 0x19191908, 0x08080819, 0x1919192b, 0x08080819, 0x19192b19, 0x08080819, + 0x192b0808, 0x08080819, 0x192b1919, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, + 0x2b081908, 0x08080819, 0x2b190808, 0x08080819, 0x2b19082b, 0x08080819, 0x2b191919, 0x08080819, + 0x2b2b0819, 0x08080819, 0x2b2b1908, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, + 0x08081919, 0x0808082b, 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, + 0x082b0808, 0x0808082b, 0x082b2b2b, 0x0808082b, 0x19080819, 0x0808082b, 0x19081908, 0x0808082b, + 0x1908192b, 0x0808082b, 0x19082b19, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b, + 0x2b080808, 0x0808082b, 0x2b081919, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x2b191908, 0x0808082b, + 0x2b2b082b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x0808192b, 0x08081908, + 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, 0x08191919, 0x08081908, + 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, 0x082b192b, 0x08081908, + 0x082b2b19, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, + 0x19082b08, 0x08081908, 0x19082b2b, 0x08081908, 0x19190819, 0x08081908, 0x19191908, 0x08081908, + 0x1919192b, 0x08081908, 0x19192b19, 0x08081908, 0x192b0808, 0x08081908, 0x192b082b, 0x08081908, + 0x192b1919, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b08192b, 0x08081908, + 0x2b082b19, 0x08081908, 0x2b190808, 0x08081908, 0x2b191919, 0x08081908, 0x2b192b08, 0x08081908, + 0x2b2b0819, 0x08081908, 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, + 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08082b2b, 0x08081919, 0x08190819, 0x08081919, + 0x08191908, 0x08081919, 0x0819192b, 0x08081919, 0x08192b19, 0x08081919, 0x082b0808, 0x08081919, + 0x082b1919, 0x08081919, 0x082b2b08, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919, + 0x1908192b, 0x08081919, 0x19082b19, 0x08081919, 0x19190808, 0x08081919, 0x1919082b, 0x08081919, + 0x19191919, 0x08081919, 0x19192b08, 0x08081919, 0x192b0819, 0x08081919, 0x192b1908, 0x08081919, + 0x2b080808, 0x08081919, 0x2b08082b, 0x08081919, 0x2b081919, 0x08081919, 0x2b082b08, 0x08081919, + 0x2b190819, 0x08081919, 0x2b191908, 0x08081919, 0x2b2b0808, 0x08081919, 0x08080819, 0x0808192b, + 0x08081908, 0x0808192b, 0x0808192b, 0x0808192b, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, + 0x08191919, 0x0808192b, 0x19080808, 0x0808192b, 0x19081919, 0x0808192b, 0x19082b08, 0x0808192b, + 0x19190819, 0x0808192b, 0x19191908, 0x0808192b, 0x192b0808, 0x0808192b, 0x2b080819, 0x0808192b, + 0x2b081908, 0x0808192b, 0x2b190808, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08, + 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08190819, 0x08082b08, 0x08191908, 0x08082b08, + 0x0819192b, 0x08082b08, 0x08192b19, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, + 0x082b2b2b, 0x08082b08, 0x19080819, 0x08082b08, 0x19081908, 0x08082b08, 0x1908192b, 0x08082b08, + 0x19082b19, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x19191919, 0x08082b08, + 0x19192b08, 0x08082b08, 0x192b0819, 0x08082b08, 0x192b1908, 0x08082b08, 0x2b080808, 0x08082b08, + 0x2b081919, 0x08082b08, 0x2b191908, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, + 0x08081908, 0x08082b19, 0x08190808, 0x08082b19, 0x0819082b, 0x08082b19, 0x08191919, 0x08082b19, + 0x08192b08, 0x08082b19, 0x082b0819, 0x08082b19, 0x19080808, 0x08082b19, 0x19081919, 0x08082b19, + 0x19082b08, 0x08082b19, 0x19190819, 0x08082b19, 0x19191908, 0x08082b19, 0x192b0808, 0x08082b19, + 0x2b080819, 0x08082b19, 0x2b190808, 0x08082b19, 0x08080808, 0x08082b2b, 0x08190819, 0x08082b2b, + 0x08191908, 0x08082b2b, 0x082b082b, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x082b2b2b, 0x08082b2b, + 0x19190808, 0x08082b2b, 0x2b192b19, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, + 0x0808192b, 0x08190808, 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, + 0x08191919, 0x08190808, 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, + 0x082b192b, 0x08190808, 0x19080808, 0x08190808, 0x1908082b, 0x08190808, 0x19081919, 0x08190808, + 0x19082b08, 0x08190808, 0x19190819, 0x08190808, 0x19191908, 0x08190808, 0x1919192b, 0x08190808, + 0x19192b19, 0x08190808, 0x192b0808, 0x08190808, 0x192b082b, 0x08190808, 0x192b1919, 0x08190808, + 0x192b2b08, 0x08190808, 0x2b080819, 0x08190808, 0x2b081908, 0x08190808, 0x2b08192b, 0x08190808, + 0x2b190808, 0x08190808, 0x2b191919, 0x08190808, 0x2b192b08, 0x08190808, 0x2b2b0819, 0x08190808, + 0x2b2b1908, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, 0x08081919, 0x08190819, + 0x08082b08, 0x08190819, 0x08082b2b, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819, + 0x0819192b, 0x08190819, 0x08192b19, 0x08190819, 0x082b0808, 0x08190819, 0x082b082b, 0x08190819, + 0x082b1919, 0x08190819, 0x082b2b08, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, + 0x1908192b, 0x08190819, 0x19082b19, 0x08190819, 0x19190808, 0x08190819, 0x1919082b, 0x08190819, + 0x19191919, 0x08190819, 0x19192b08, 0x08190819, 0x192b0819, 0x08190819, 0x192b1908, 0x08190819, + 0x2b080808, 0x08190819, 0x2b08082b, 0x08190819, 0x2b081919, 0x08190819, 0x2b082b08, 0x08190819, + 0x2b190819, 0x08190819, 0x2b191908, 0x08190819, 0x08080819, 0x0819082b, 0x08081908, 0x0819082b, + 0x08082b19, 0x0819082b, 0x08190808, 0x0819082b, 0x08191919, 0x0819082b, 0x082b0819, 0x0819082b, + 0x082b1908, 0x0819082b, 0x19080808, 0x0819082b, 0x19081919, 0x0819082b, 0x19190819, 0x0819082b, + 0x19191908, 0x0819082b, 0x2b080819, 0x0819082b, 0x2b081908, 0x0819082b, 0x2b190808, 0x0819082b, + 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, 0x08082b08, 0x08191908, + 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x0819192b, 0x08191908, 0x08192b19, 0x08191908, + 0x082b0808, 0x08191908, 0x082b1919, 0x08191908, 0x082b2b08, 0x08191908, 0x19080819, 0x08191908, + 0x19081908, 0x08191908, 0x1908192b, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908, + 0x1919082b, 0x08191908, 0x19191919, 0x08191908, 0x19192b08, 0x08191908, 0x192b0819, 0x08191908, + 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x2b08082b, 0x08191908, 0x2b081919, 0x08191908, + 0x2b082b08, 0x08191908, 0x2b190819, 0x08191908, 0x2b191908, 0x08191908, 0x2b2b0808, 0x08191908, + 0x08080819, 0x08191919, 0x08081908, 0x08191919, 0x0808192b, 0x08191919, 0x08082b19, 0x08191919, + 0x08190808, 0x08191919, 0x0819082b, 0x08191919, 0x08191919, 0x08191919, 0x08192b08, 0x08191919, + 0x082b0819, 0x08191919, 0x082b1908, 0x08191919, 0x19080808, 0x08191919, 0x1908082b, 0x08191919, + 0x19081919, 0x08191919, 0x19082b08, 0x08191919, 0x19190819, 0x08191919, 0x19191908, 0x08191919, + 0x192b0808, 0x08191919, 0x2b080819, 0x08191919, 0x2b081908, 0x08191919, 0x2b190808, 0x08191919, + 0x08080808, 0x0819192b, 0x08081919, 0x0819192b, 0x08082b08, 0x0819192b, 0x08190819, 0x0819192b, + 0x08191908, 0x0819192b, 0x082b0808, 0x0819192b, 0x19080819, 0x0819192b, 0x19081908, 0x0819192b, + 0x19190808, 0x0819192b, 0x2b080808, 0x0819192b, 0x2b2b2b2b, 0x0819192b, 0x08080819, 0x08192b08, + 0x08081908, 0x08192b08, 0x0808192b, 0x08192b08, 0x08082b19, 0x08192b08, 0x08190808, 0x08192b08, + 0x08191919, 0x08192b08, 0x08192b08, 0x08192b08, 0x082b0819, 0x08192b08, 0x19080808, 0x08192b08, + 0x1908082b, 0x08192b08, 0x19081919, 0x08192b08, 0x19082b08, 0x08192b08, 0x19190819, 0x08192b08, + 0x19191908, 0x08192b08, 0x192b0808, 0x08192b08, 0x2b080819, 0x08192b08, 0x2b081908, 0x08192b08, + 0x08080808, 0x08192b19, 0x0808082b, 0x08192b19, 0x08081919, 0x08192b19, 0x08082b08, 0x08192b19, + 0x08190819, 0x08192b19, 0x08191908, 0x08192b19, 0x082b0808, 0x08192b19, 0x19080819, 0x08192b19, + 0x19081908, 0x08192b19, 0x19190808, 0x08192b19, 0x192b2b19, 0x08192b19, 0x2b2b082b, 0x08192b19, + 0x08081908, 0x08192b2b, 0x08190808, 0x08192b2b, 0x19080808, 0x08192b2b, 0x1919192b, 0x08192b2b, + 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, 0x08082b08, 0x082b0808, + 0x08190819, 0x082b0808, 0x08191908, 0x082b0808, 0x0819192b, 0x082b0808, 0x08192b19, 0x082b0808, + 0x082b0808, 0x082b0808, 0x082b1919, 0x082b0808, 0x082b2b2b, 0x082b0808, 0x19080819, 0x082b0808, + 0x19081908, 0x082b0808, 0x19190808, 0x082b0808, 0x1919082b, 0x082b0808, 0x19191919, 0x082b0808, + 0x192b1908, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b082b2b, 0x082b0808, 0x2b191908, 0x082b0808, + 0x2b2b2b2b, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, 0x08190808, 0x082b0819, + 0x0819082b, 0x082b0819, 0x08191919, 0x082b0819, 0x082b0819, 0x082b0819, 0x19080808, 0x082b0819, + 0x1908082b, 0x082b0819, 0x19081919, 0x082b0819, 0x19190819, 0x082b0819, 0x19191908, 0x082b0819, + 0x192b0808, 0x082b0819, 0x2b080819, 0x082b0819, 0x2b081908, 0x082b0819, 0x2b190808, 0x082b0819, + 0x08080808, 0x082b082b, 0x08082b2b, 0x082b082b, 0x082b082b, 0x082b082b, 0x082b2b08, 0x082b082b, + 0x082b2b2b, 0x082b082b, 0x19081908, 0x082b082b, 0x19190808, 0x082b082b, 0x2b082b08, 0x082b082b, + 0x2b082b2b, 0x082b082b, 0x2b2b2b08, 0x082b082b, 0x08080819, 0x082b1908, 0x08081908, 0x082b1908, + 0x0808192b, 0x082b1908, 0x08082b19, 0x082b1908, 0x08190808, 0x082b1908, 0x08191919, 0x082b1908, + 0x08192b08, 0x082b1908, 0x082b0819, 0x082b1908, 0x082b1908, 0x082b1908, 0x19080808, 0x082b1908, + 0x1908082b, 0x082b1908, 0x19081919, 0x082b1908, 0x19082b08, 0x082b1908, 0x19190819, 0x082b1908, + 0x19191908, 0x082b1908, 0x192b0808, 0x082b1908, 0x2b080819, 0x082b1908, 0x2b081908, 0x082b1908, + 0x2b190808, 0x082b1908, 0x08080808, 0x082b1919, 0x08081919, 0x082b1919, 0x08082b08, 0x082b1919, + 0x08190819, 0x082b1919, 0x08191908, 0x082b1919, 0x082b0808, 0x082b1919, 0x19080819, 0x082b1919, + 0x19081908, 0x082b1919, 0x19190808, 0x082b1919, 0x192b192b, 0x082b1919, 0x2b080808, 0x082b1919, + 0x08080819, 0x082b192b, 0x08081908, 0x082b192b, 0x08190808, 0x082b192b, 0x19080808, 0x082b192b, + 0x19192b19, 0x082b192b, 0x08080808, 0x082b2b08, 0x08081919, 0x082b2b08, 0x08190819, 0x082b2b08, + 0x08191908, 0x082b2b08, 0x19080819, 0x082b2b08, 0x19081908, 0x082b2b08, 0x19190808, 0x082b2b08, + 0x2b082b2b, 0x082b2b08, 0x2b2b2b2b, 0x082b2b08, 0x08080819, 0x082b2b19, 0x08081908, 0x082b2b19, + 0x08190808, 0x082b2b19, 0x2b191919, 0x082b2b19, 0x08082b2b, 0x082b2b2b, 0x082b082b, 0x082b2b2b, + 0x192b1908, 0x082b2b2b, 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, + 0x08081908, 0x19080808, 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, + 0x0819082b, 0x19080808, 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x08192b2b, 0x19080808, + 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x082b192b, 0x19080808, 0x19080808, 0x19080808, + 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, 0x19082b2b, 0x19080808, + 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x1919192b, 0x19080808, 0x19192b19, 0x19080808, + 0x192b0808, 0x19080808, 0x192b082b, 0x19080808, 0x192b1919, 0x19080808, 0x2b080819, 0x19080808, + 0x2b081908, 0x19080808, 0x2b190808, 0x19080808, 0x2b191919, 0x19080808, 0x2b192b08, 0x19080808, + 0x2b2b0819, 0x19080808, 0x2b2b1908, 0x19080808, 0x08080808, 0x19080819, 0x0808082b, 0x19080819, + 0x08081919, 0x19080819, 0x08082b08, 0x19080819, 0x08190819, 0x19080819, 0x08191908, 0x19080819, + 0x0819192b, 0x19080819, 0x08192b19, 0x19080819, 0x082b0808, 0x19080819, 0x082b082b, 0x19080819, + 0x082b1919, 0x19080819, 0x19080819, 0x19080819, 0x19081908, 0x19080819, 0x1908192b, 0x19080819, + 0x19082b19, 0x19080819, 0x19190808, 0x19080819, 0x1919082b, 0x19080819, 0x19191919, 0x19080819, + 0x19192b08, 0x19080819, 0x192b0819, 0x19080819, 0x192b1908, 0x19080819, 0x2b080808, 0x19080819, + 0x2b08082b, 0x19080819, 0x2b081919, 0x19080819, 0x2b082b08, 0x19080819, 0x2b190819, 0x19080819, + 0x2b191908, 0x19080819, 0x2b2b0808, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, + 0x08190808, 0x1908082b, 0x0819082b, 0x1908082b, 0x08191919, 0x1908082b, 0x08192b08, 0x1908082b, + 0x082b1908, 0x1908082b, 0x19080808, 0x1908082b, 0x19081919, 0x1908082b, 0x19082b08, 0x1908082b, + 0x19190819, 0x1908082b, 0x19191908, 0x1908082b, 0x192b0808, 0x1908082b, 0x2b080819, 0x1908082b, + 0x2b081908, 0x1908082b, 0x08080808, 0x19081908, 0x0808082b, 0x19081908, 0x08081919, 0x19081908, + 0x08082b08, 0x19081908, 0x08082b2b, 0x19081908, 0x08190819, 0x19081908, 0x08191908, 0x19081908, + 0x0819192b, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x082b082b, 0x19081908, + 0x082b1919, 0x19081908, 0x082b2b08, 0x19081908, 0x19080819, 0x19081908, 0x19081908, 0x19081908, + 0x1908192b, 0x19081908, 0x19082b19, 0x19081908, 0x19190808, 0x19081908, 0x1919082b, 0x19081908, + 0x19191919, 0x19081908, 0x19192b08, 0x19081908, 0x192b0819, 0x19081908, 0x192b1908, 0x19081908, + 0x2b080808, 0x19081908, 0x2b08082b, 0x19081908, 0x2b081919, 0x19081908, 0x2b082b08, 0x19081908, + 0x2b190819, 0x19081908, 0x2b191908, 0x19081908, 0x2b2b0808, 0x19081908, 0x08080819, 0x19081919, + 0x08081908, 0x19081919, 0x0808192b, 0x19081919, 0x08082b19, 0x19081919, 0x08190808, 0x19081919, + 0x0819082b, 0x19081919, 0x08191919, 0x19081919, 0x08192b08, 0x19081919, 0x082b0819, 0x19081919, + 0x082b1908, 0x19081919, 0x19080808, 0x19081919, 0x1908082b, 0x19081919, 0x19081919, 0x19081919, + 0x19082b08, 0x19081919, 0x19190819, 0x19081919, 0x19191908, 0x19081919, 0x192b0808, 0x19081919, + 0x192b2b2b, 0x19081919, 0x2b080819, 0x19081919, 0x2b081908, 0x19081919, 0x2b190808, 0x19081919, + 0x08080808, 0x1908192b, 0x0808082b, 0x1908192b, 0x08081919, 0x1908192b, 0x08082b08, 0x1908192b, + 0x08190819, 0x1908192b, 0x08191908, 0x1908192b, 0x082b0808, 0x1908192b, 0x19080819, 0x1908192b, + 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x2b080808, 0x1908192b, 0x2b2b1919, 0x1908192b, + 0x08080819, 0x19082b08, 0x08081908, 0x19082b08, 0x08082b19, 0x19082b08, 0x08190808, 0x19082b08, + 0x0819082b, 0x19082b08, 0x08191919, 0x19082b08, 0x08192b08, 0x19082b08, 0x082b0819, 0x19082b08, + 0x082b1908, 0x19082b08, 0x19080808, 0x19082b08, 0x1908082b, 0x19082b08, 0x19081919, 0x19082b08, + 0x19082b08, 0x19082b08, 0x19190819, 0x19082b08, 0x19191908, 0x19082b08, 0x192b0808, 0x19082b08, + 0x2b081908, 0x19082b08, 0x2b190808, 0x19082b08, 0x08080808, 0x19082b19, 0x0808082b, 0x19082b19, + 0x08081919, 0x19082b19, 0x08082b08, 0x19082b19, 0x08190819, 0x19082b19, 0x08191908, 0x19082b19, + 0x082b0808, 0x19082b19, 0x19080819, 0x19082b19, 0x19081908, 0x19082b19, 0x19190808, 0x19082b19, + 0x2b080808, 0x19082b19, 0x2b19192b, 0x19082b19, 0x08080819, 0x19082b2b, 0x08081908, 0x19082b2b, + 0x08190808, 0x19082b2b, 0x19080808, 0x19082b2b, 0x08080808, 0x19190808, 0x0808082b, 0x19190808, + 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, 0x08191908, 0x19190808, + 0x0819192b, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x082b082b, 0x19190808, + 0x082b1919, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, 0x19081908, 0x19190808, + 0x1908192b, 0x19190808, 0x19082b19, 0x19190808, 0x19190808, 0x19190808, 0x1919082b, 0x19190808, + 0x19191919, 0x19190808, 0x19192b08, 0x19190808, 0x192b0819, 0x19190808, 0x192b1908, 0x19190808, + 0x2b080808, 0x19190808, 0x2b08082b, 0x19190808, 0x2b081919, 0x19190808, 0x2b082b08, 0x19190808, + 0x2b190819, 0x19190808, 0x2b191908, 0x19190808, 0x08080819, 0x19190819, 0x08081908, 0x19190819, + 0x0808192b, 0x19190819, 0x08082b19, 0x19190819, 0x08190808, 0x19190819, 0x0819082b, 0x19190819, + 0x08191919, 0x19190819, 0x08192b08, 0x19190819, 0x082b0819, 0x19190819, 0x082b1908, 0x19190819, + 0x19080808, 0x19190819, 0x1908082b, 0x19190819, 0x19081919, 0x19190819, 0x19082b08, 0x19190819, + 0x19190819, 0x19190819, 0x19191908, 0x19190819, 0x192b0808, 0x19190819, 0x2b080819, 0x19190819, + 0x2b081908, 0x19190819, 0x2b190808, 0x19190819, 0x08080808, 0x1919082b, 0x08081919, 0x1919082b, + 0x08082b08, 0x1919082b, 0x08190819, 0x1919082b, 0x08191908, 0x1919082b, 0x082b0808, 0x1919082b, + 0x19080819, 0x1919082b, 0x19081908, 0x1919082b, 0x19190808, 0x1919082b, 0x192b2b19, 0x1919082b, + 0x2b080808, 0x1919082b, 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x0808192b, 0x19191908, + 0x08082b19, 0x19191908, 0x08190808, 0x19191908, 0x0819082b, 0x19191908, 0x08191919, 0x19191908, + 0x08192b08, 0x19191908, 0x082b0819, 0x19191908, 0x082b1908, 0x19191908, 0x19080808, 0x19191908, + 0x1908082b, 0x19191908, 0x19081919, 0x19191908, 0x19082b08, 0x19191908, 0x19190819, 0x19191908, + 0x19191908, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b081908, 0x19191908, + 0x2b190808, 0x19191908, 0x08080808, 0x19191919, 0x0808082b, 0x19191919, 0x08081919, 0x19191919, + 0x08082b08, 0x19191919, 0x08190819, 0x19191919, 0x08191908, 0x19191919, 0x082b0808, 0x19191919, + 0x19080819, 0x19191919, 0x19081908, 0x19191919, 0x19190808, 0x19191919, 0x2b080808, 0x19191919, + 0x08080819, 0x1919192b, 0x08081908, 0x1919192b, 0x08190808, 0x1919192b, 0x082b192b, 0x1919192b, + 0x19080808, 0x1919192b, 0x08080808, 0x19192b08, 0x0808082b, 0x19192b08, 0x08081919, 0x19192b08, + 0x08082b08, 0x19192b08, 0x08190819, 0x19192b08, 0x08191908, 0x19192b08, 0x082b0808, 0x19192b08, + 0x19080819, 0x19192b08, 0x19081908, 0x19192b08, 0x19190808, 0x19192b08, 0x19192b2b, 0x19192b08, + 0x2b080808, 0x19192b08, 0x08080819, 0x19192b19, 0x08081908, 0x19192b19, 0x08190808, 0x19192b19, + 0x19080808, 0x19192b19, 0x08080808, 0x19192b2b, 0x08192b19, 0x19192b2b, 0x2b081919, 0x19192b2b, + 0x2b2b2b08, 0x19192b2b, 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x0808192b, 0x192b0808, + 0x08190808, 0x192b0808, 0x0819082b, 0x192b0808, 0x08191919, 0x192b0808, 0x08192b08, 0x192b0808, + 0x082b0819, 0x192b0808, 0x082b1908, 0x192b0808, 0x19080808, 0x192b0808, 0x19081919, 0x192b0808, + 0x19082b08, 0x192b0808, 0x19190819, 0x192b0808, 0x19191908, 0x192b0808, 0x192b0808, 0x192b0808, + 0x2b081908, 0x192b0808, 0x2b190808, 0x192b0808, 0x08080808, 0x192b0819, 0x0808082b, 0x192b0819, + 0x08081919, 0x192b0819, 0x08082b08, 0x192b0819, 0x08190819, 0x192b0819, 0x08191908, 0x192b0819, + 0x082b0808, 0x192b0819, 0x19080819, 0x192b0819, 0x19081908, 0x192b0819, 0x19190808, 0x192b0819, + 0x2b080808, 0x192b0819, 0x2b192b19, 0x192b0819, 0x08081908, 0x192b082b, 0x08190808, 0x192b082b, + 0x19080808, 0x192b082b, 0x1919192b, 0x192b082b, 0x2b2b0819, 0x192b082b, 0x08080808, 0x192b1908, + 0x08081919, 0x192b1908, 0x08082b08, 0x192b1908, 0x08190819, 0x192b1908, 0x08191908, 0x192b1908, + 0x082b0808, 0x192b1908, 0x19080819, 0x192b1908, 0x19081908, 0x192b1908, 0x19190808, 0x192b1908, + 0x2b080808, 0x192b1908, 0x08080819, 0x192b1919, 0x08081908, 0x192b1919, 0x08190808, 0x192b1919, + 0x19080808, 0x192b1919, 0x19082b2b, 0x192b1919, 0x192b2b08, 0x192b1919, 0x2b19082b, 0x192b1919, + 0x08080808, 0x192b192b, 0x2b191908, 0x192b192b, 0x08080819, 0x192b2b08, 0x08081908, 0x192b2b08, + 0x08190808, 0x192b2b08, 0x192b1919, 0x192b2b08, 0x2b192b08, 0x192b2b08, 0x08080808, 0x192b2b19, + 0x082b2b2b, 0x192b2b19, 0x1908082b, 0x192b2b2b, 0x2b2b0819, 0x192b2b2b, 0x08080808, 0x2b080808, + 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, 0x08190819, 0x2b080808, + 0x08191908, 0x2b080808, 0x08192b19, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b1919, 0x2b080808, + 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x1919082b, 0x2b080808, + 0x19191919, 0x2b080808, 0x19192b08, 0x2b080808, 0x192b0819, 0x2b080808, 0x2b080808, 0x2b080808, + 0x2b081919, 0x2b080808, 0x2b190819, 0x2b080808, 0x2b191908, 0x2b080808, 0x08080819, 0x2b080819, + 0x08081908, 0x2b080819, 0x08082b19, 0x2b080819, 0x08190808, 0x2b080819, 0x0819082b, 0x2b080819, + 0x08191919, 0x2b080819, 0x08192b08, 0x2b080819, 0x082b0819, 0x2b080819, 0x082b1908, 0x2b080819, + 0x19080808, 0x2b080819, 0x1908082b, 0x2b080819, 0x19081919, 0x2b080819, 0x19082b08, 0x2b080819, + 0x19190819, 0x2b080819, 0x19191908, 0x2b080819, 0x2b080819, 0x2b080819, 0x2b081908, 0x2b080819, + 0x2b190808, 0x2b080819, 0x2b2b2b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x08081919, 0x2b08082b, + 0x08082b2b, 0x2b08082b, 0x08190819, 0x2b08082b, 0x08191908, 0x2b08082b, 0x19080819, 0x2b08082b, + 0x19081908, 0x2b08082b, 0x19190808, 0x2b08082b, 0x08080819, 0x2b081908, 0x08081908, 0x2b081908, + 0x0808192b, 0x2b081908, 0x08082b19, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908, + 0x08191919, 0x2b081908, 0x08192b08, 0x2b081908, 0x082b0819, 0x2b081908, 0x19080808, 0x2b081908, + 0x1908082b, 0x2b081908, 0x19081919, 0x2b081908, 0x19082b08, 0x2b081908, 0x19190819, 0x2b081908, + 0x19191908, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b080819, 0x2b081908, 0x2b081908, 0x2b081908, + 0x2b190808, 0x2b081908, 0x08080808, 0x2b081919, 0x0808082b, 0x2b081919, 0x08081919, 0x2b081919, + 0x08082b08, 0x2b081919, 0x08190819, 0x2b081919, 0x08191908, 0x2b081919, 0x082b0808, 0x2b081919, + 0x19080819, 0x2b081919, 0x19081908, 0x2b081919, 0x19190808, 0x2b081919, 0x2b080808, 0x2b081919, + 0x2b082b2b, 0x2b081919, 0x08080819, 0x2b08192b, 0x08081908, 0x2b08192b, 0x08190808, 0x2b08192b, + 0x082b2b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08081919, 0x2b082b08, + 0x08190819, 0x2b082b08, 0x08191908, 0x2b082b08, 0x19080819, 0x2b082b08, 0x19081908, 0x2b082b08, + 0x19190808, 0x2b082b08, 0x2b2b082b, 0x2b082b08, 0x08080819, 0x2b082b19, 0x08081908, 0x2b082b19, + 0x19080808, 0x2b082b19, 0x192b1919, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x19192b08, 0x2b082b2b, + 0x19192b2b, 0x2b082b2b, 0x2b08082b, 0x2b082b2b, 0x2b2b082b, 0x2b082b2b, 0x08080819, 0x2b190808, + 0x08081908, 0x2b190808, 0x08082b19, 0x2b190808, 0x08190808, 0x2b190808, 0x0819082b, 0x2b190808, + 0x08191919, 0x2b190808, 0x08192b08, 0x2b190808, 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, + 0x1908082b, 0x2b190808, 0x19081919, 0x2b190808, 0x19082b08, 0x2b190808, 0x19190819, 0x2b190808, + 0x19191908, 0x2b190808, 0x192b0808, 0x2b190808, 0x2b080819, 0x2b190808, 0x2b081908, 0x2b190808, + 0x2b190808, 0x2b190808, 0x08080808, 0x2b190819, 0x08081919, 0x2b190819, 0x08190819, 0x2b190819, + 0x08191908, 0x2b190819, 0x19080819, 0x2b190819, 0x19081908, 0x2b190819, 0x19190808, 0x2b190819, + 0x19192b2b, 0x2b190819, 0x08080819, 0x2b19082b, 0x08081908, 0x2b19082b, 0x08190808, 0x2b19082b, + 0x19080808, 0x2b19082b, 0x2b2b192b, 0x2b19082b, 0x08080808, 0x2b191908, 0x0808082b, 0x2b191908, + 0x08081919, 0x2b191908, 0x08082b08, 0x2b191908, 0x08190819, 0x2b191908, 0x08191908, 0x2b191908, + 0x082b0808, 0x2b191908, 0x19080819, 0x2b191908, 0x19081908, 0x2b191908, 0x19190808, 0x2b191908, + 0x2b080808, 0x2b191908, 0x2b19192b, 0x2b191908, 0x08080819, 0x2b191919, 0x08081908, 0x2b191919, + 0x08190808, 0x2b191919, 0x19080808, 0x2b191919, 0x2b192b08, 0x2b191919, 0x2b2b0819, 0x2b191919, + 0x08080808, 0x2b19192b, 0x1908192b, 0x2b19192b, 0x192b1908, 0x2b19192b, 0x08080819, 0x2b192b08, + 0x08081908, 0x2b192b08, 0x08190808, 0x2b192b08, 0x082b192b, 0x2b192b08, 0x19080808, 0x2b192b08, + 0x2b2b2b19, 0x2b192b08, 0x08080808, 0x2b192b19, 0x19082b19, 0x2b192b19, 0x1919082b, 0x2b192b19, + 0x2b190808, 0x2b192b2b, 0x08080808, 0x2b2b0808, 0x08081919, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, + 0x08191908, 0x2b2b0808, 0x082b082b, 0x2b2b0808, 0x082b2b2b, 0x2b2b0808, 0x19080819, 0x2b2b0808, + 0x19081908, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b2b082b, 0x2b2b0808, 0x2b2b2b2b, 0x2b2b0808, + 0x19080808, 0x2b2b0819, 0x192b1919, 0x2b2b0819, 0x0808082b, 0x2b2b082b, 0x08082b2b, 0x2b2b082b, + 0x082b082b, 0x2b2b082b, 0x082b2b08, 0x2b2b082b, 0x082b2b2b, 0x2b2b082b, 0x2b08082b, 0x2b2b082b, + 0x2b082b08, 0x2b2b082b, 0x2b082b2b, 0x2b2b082b, 0x2b2b2b08, 0x2b2b082b, 0x08080819, 0x2b2b1908, + 0x08081908, 0x2b2b1908, 0x08190808, 0x2b2b1908, 0x19080808, 0x2b2b1908, 0x2b082b19, 0x2b2b1908, + 0x2b2b1908, 0x2b2b1908, 0x08080808, 0x2b2b1919, 0x08192b19, 0x2b2b1919, 0x19190819, 0x2b2b192b, + 0x08082b2b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b082b, 0x2b2b2b08, 0x19191908, 0x2b2b2b19, + 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, + 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b +); + +struct iq2_s { + d: f16, + qs: array, + qh: array, + scales: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; + for (var i: u32 = 0; i < 16; i++) { + qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var scale_vals = array( + bitcast(vec2(block.scales[0], block.scales[1])), + bitcast(vec2(block.scales[2], block.scales[3])) + ); + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib ++) { + let s = get_byte(scale_vals[ib / 4], ib % 4); + let db = array( + d * (0.5 + f32(s & 0xF)) * 0.25, + d * (0.5 + f32(s >> 4)) * 0.25 + ); + let qs_w = qs_vals[ib]; + for (var l: u32 = 0; l < 4; l++) { + let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; + let ig = (get_byte(qs_w, l) | qh_b) * 8; + let signs = get_byte(qs_vals[ib + 8], l); + let dl = db[l/2]; + for (var j: u32 = 0; j < 8; j++) { + let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); + sum += dl * f32(g) * m * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + + +#enddecl(IQ2_S) + +#decl(IQ3_XSS) + +const iq3xxs_grid = array( + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 +); + +struct iq3_xxs { + d: f16, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 16; ib += 2) { + let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; + for (var l: u32 = 0; l < 4; l++) { + let is = (sc_sign >> (7 * l)) & 127; + let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); + let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig1 = get_byte(ig_val, 0); + let ig2 = get_byte(ig_val, 1); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3xxs_grid[ig1], j); + let g2 = get_byte(iq3xxs_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += db * f32(g1) * m1 * src1[src1_i]; + sum += db * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + return sum; +} + +#enddecl(IQ3_XSS) + +#decl(IQ3_S) + +const iq3s_grid = array( + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 +); + +struct iq3_s { + d: f16, + qs: array, + qh: array, + signs: array, + scales: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( + bitcast(vec2(block.qh[0], block.qh[1])), + bitcast(vec2(block.qh[2], block.qh[3])) + ); + var sign_vals: array; + for (var i: u32 = 0; i < 8; i++) { + sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + } + var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + var sum = 0.0; + for (var ib: u32 = 0; ib < 4; ib++) { + let s = get_byte(scale_vals, ib); + let db = array( + d * (1.0 + 2.0 * f32(s & 0xF)), + d * (1.0 + 2.0 * f32(s >> 4)) + ); + for (var k: u32 = 0; k < 2; k++) { + let dl = db[k]; + let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); + let sign_w = sign_vals[ib * 2 + k]; + for (var l: u32 = 0; l < 4; l++) { + let signs = get_byte(sign_w, l); + let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); + let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); + for (var j: u32 = 0; j < 4; j++) { + let g1 = get_byte(iq3s_grid[ig1], j); + let g2 = get_byte(iq3s_grid[ig2], j); + let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); + let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); + sum += dl * f32(g1) * m1 * src1[src1_i]; + sum += dl * f32(g2) * m2 * src1[src1_i + 4]; + src1_i++; + } + src1_i += 4; + } + } + } + return sum; +} +#enddecl(IQ3_S) + +#decl(IQ1_TABLE) + +const IQ1_DELTA: f32 = 0.125; + +const iq1_grid = array( + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +); + +#enddecl(IQ1_TABLE) + +#decl(IQ1_S) + +struct iq1_s { + d: f16, + qs: array, + qh: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let qh = bitcast(vec2(block.qh[ib], 0.0)); + let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); + let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + for (var l: u32 = 0; l < 4; l++) { + let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl * (f32(gs) + delta) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ1_S) + +#decl(IQ1_M) + +struct iq1_m { + qs: array, + qh: array, + scales: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + + let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); + let d = f32(bitcast>(scale).x); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; + let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; + let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; + var dl = array( + d * f32(2 * s1 + 1), + d * f32(2 * s2 + 1) + ); + + let qh = block.qh[ib / 2] >> (16 * (ib % 2)); + var idx = array( + get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), + get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), + get_byte(block.qs[ib], 2) | ((qh) & 0x700), + get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) + ); + var delta = array( + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), + select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) + ); + for (var l: u32 = 0; l < 4; l++) { + let ig = idx[l] * 8; + for (var j: u32 = 0; j < 8; j++) { + let gw = iq1_grid[(ig + j) / 16]; + let g = (gw >> (((ig + j) % 16) * 2)) & 3; + let gs = bitcast(g << 30) >> 30; + sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; + src1_i++; + } + } + } + return sum; +} + +#enddecl(IQ1_M) + +#decl(IQ4_TABLE) + +const kvalues_iq4nl = array( + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +); + +#enddecl(IQ4_TABLE) + +#decl(IQ4_NL) + +struct iq4_nl { + d: f16, + qs: array, +} + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + var src1_i = src1_idx_base + offset * 32; + var sum = 0.0; + var qs: array; + for (var i: u32 = 0; i < 4; i++) { + qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + } + for (var j: u32 = 0; j < 16; j++) { + let qsb = get_byte(qs[j / 4], j % 4); + sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + return sum; +} + +#enddecl(IQ4_NL) + +#decl(IQ4_XS) + +struct iq4_xs { + d: f16, + scales_h: f16, + scales_l: u32, + qs: array +}; + +fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { + let block = src0[src0_idx_base + offset]; + let d = f32(block.d); + let scales_h = bitcast(vec2(block.scales_h, 0.0)); + var src1_i = src1_idx_base + offset * 256; + var sum = 0.0; + for (var ib: u32 = 0; ib < 8; ib++) { + let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); + let dl = d * (f32(ls) - 32.0); + for (var j: u32 = 0; j < 16; j++) { + let iqs = ib * 16 + j; + let qsb = get_byte(block.qs[iqs / 4], iqs % 4); + sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; + sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; + src1_i++; + } + src1_i += 16; + } + return sum; +} + +#enddecl(IQ4_XS) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, // in elements/blocks + offset_src1: u32, // in elements/blocks + offset_dst: u32, // in elements/blocks + m: u32, + n: u32, + k: u32, + // all strides are in elements/blocks + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns + +@group(0) @binding(3) var params: MulMatParams; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (global_id.x >= total) { + return; + } + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = global_id.x / dst3_stride; + let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension + let src13_idx = dst3_idx; // src1 is not broadcast + let dst3_rem = global_id.x % dst3_stride; + + let dst2_idx = dst3_rem / dst2_stride; + let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension + let src12_idx = dst2_idx; // src1 is not broadcast + + let dst2_rem = dst3_rem % dst2_stride; + + let row = dst2_rem / params.n; // output row + let col = dst2_rem % params.n; // output column + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; + + var sum = 0.0; + for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { + sum += multiply_add(src0_idx_base, src1_idx_base, i); + } + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index 054aab566..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,56 +0,0 @@ -struct MulMatParams { - m: u32, - n: u32, - k: u32, - // all strides are in elements - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array; // N rows, K columns -@group(0) @binding(1) var src1: array; // M rows, K columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(64) -fn main(@builtin(global_invocation_id) global_id: vec3) { - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_id.x / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k; i = i + 1u) { - let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i; - let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i; - sum = sum + src0[src0_idx] * src1[src1_idx]; - } - dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; -} From e92734d51bcb82cc35f0a6b5a14928f0036b2c90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 22 Aug 2025 23:47:01 +0200 Subject: [PATCH 2/7] test-opt: allow slight inprecision (#15503) --- tests/test-opt.cpp | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index f02b4cad8..18d3fcf2c 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -358,7 +358,7 @@ static std::pair test_forward_backward( double accuracy; double accuracy_unc; ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); - const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc); + const bool subtest_ok = ndata == 0 && almost_equal(loss, 0.0, 1e-6) && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc); helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass); } @@ -381,10 +381,12 @@ static std::pair test_forward_backward( { float weights; ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); - const bool subtest_ok = weights == ndata/2; + const bool subtest_ok = almost_equal(weights, ndata/2, 1e-10); helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass); } { + constexpr double atol = 1e-10; + int64_t ndata; ggml_opt_result_ndata(cd.result, &ndata); bool subtest_ok = ndata == 6; @@ -392,7 +394,7 @@ static std::pair test_forward_backward( double loss; double loss_unc; ggml_opt_result_loss(cd.result, &loss, &loss_unc); - subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10); + subtest_ok = subtest_ok && almost_equal(loss, 33.0, atol) && almost_equal(loss_unc, sqrt(3.5), atol); double accuracy; double accuracy_unc; @@ -437,7 +439,7 @@ static std::pair test_forward_backward( { float weights; ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); - const bool subtest_ok = weights == -ndata * .5; + const bool subtest_ok = almost_equal(weights, -ndata * 0.5, 1e-10); helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass); } { @@ -448,7 +450,7 @@ static std::pair test_forward_backward( double loss; double loss_unc; ggml_opt_result_loss(cd.result, &loss, &loss_unc); - subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0); + subtest_ok = subtest_ok && almost_equal(loss, 18.0, 1e-10) && (shuffle || loss_unc == 0.0); double accuracy; double accuracy_unc; @@ -550,10 +552,12 @@ static std::pair test_idata_split( if (adamw) { float weights; ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); - const bool subtest_ok = weights == ndata/2 - epoch*idata_split; + const bool subtest_ok = almost_equal(weights, ndata/2 - epoch*idata_split, 1e-10); helper_after_test_idata_split(optim, __func__, high_level, epoch, "weights", subtest_ok, ntest, npass); } if (adamw) { + constexpr double atol = 1e-10; + int64_t ndata_result; ggml_opt_result_ndata(cd.result, &ndata_result); bool subtest_ok = ndata_result == idata_split; @@ -561,7 +565,7 @@ static std::pair test_idata_split( double loss; double loss_unc; ggml_opt_result_loss(cd.result, &loss, &loss_unc); - subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0; + subtest_ok = subtest_ok && almost_equal(loss, 28.0 - epoch*16.0, atol) && almost_equal(loss_unc, 0.0, atol); double accuracy; double accuracy_unc; @@ -571,6 +575,8 @@ static std::pair test_idata_split( helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass); } if (adamw) { + constexpr double atol = 1e-10; + int64_t ndata_result; ggml_opt_result_ndata(cd.result2, &ndata_result); bool subtest_ok = ndata_result == ndata - idata_split; @@ -578,7 +584,7 @@ static std::pair test_idata_split( double loss; double loss_unc; ggml_opt_result_loss(cd.result2, &loss, &loss_unc); - subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10); + subtest_ok = subtest_ok && almost_equal(loss, 15.0 - epoch*8, atol) && almost_equal(loss_unc, sqrt(0.5), atol); double accuracy; double accuracy_unc; @@ -687,22 +693,24 @@ static std::pair test_gradient_accumulation( } bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW; if (adamw) { + constexpr double atol = 1e-6; float weights; ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); - const bool subtest_ok = weights == (ndata/2) - epoch; + const bool subtest_ok = almost_equal(weights, (ndata/2) - epoch, atol); helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass); } { + constexpr double atol = 1e-6; int64_t ndata_result; ggml_opt_result_ndata(cd.result, &ndata_result); - bool subtest_ok = ndata_result == ndata/nbatch_physical; + bool subtest_ok = almost_equal(ndata_result, ndata/nbatch_physical, atol); double loss; ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr); if (loss_type == GGML_OPT_LOSS_TYPE_SUM) { - subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0); + subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0), atol); } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { - subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6); + subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, atol); } else { GGML_ASSERT(false); } From 330c3d2d21b55bca5517db7d2eea2ea8f131df4a Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 Aug 2025 01:31:54 -0500 Subject: [PATCH 3/7] vulkan: optimize mul_mat_id loading row ids into shared memory (#15427) - Spread the work across the whole workgroup. Using more threads seems to far outweigh the synchronization overhead. - Specialize the code for when the division is by a power of two. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 105 +++++++++++------- .../vulkan-shaders/mul_mm_cm2.comp | 103 ++++++++++------- 3 files changed, 133 insertions(+), 81 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index fb18a55cd..2c5678f48 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16, 0 }; - m_warptile_mmqid = { 256, 128, 64, 16, 0 }; - s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a61a464c7..d57cc6bde 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32; shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; +#define NUM_WARPS (BLOCK_SIZE / WARP) + #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; uint _ne1; #ifdef COOPMAT -shared uint _ne1_sh; +shared uvec4 ballots_sh[NUM_WARPS]; +void load_row_ids(uint expert_idx, bool nei0_is_pow2) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + } + barrier(); +} #endif #endif // MUL_MAT_ID -#define NUM_WARPS (BLOCK_SIZE / WARP) - #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif @@ -178,44 +236,11 @@ void main() { #ifdef MUL_MAT_ID #ifdef COOPMAT - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec2(ii0, ii1); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true); + } else { + load_row_ids(expert_idx, false); } - - barrier(); - - _ne1 = _ne1_sh; #else _ne1 = 0; for (uint ii1 = 0; ii1 < p.nei1; ii1++) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 29e4b5c9c..4d16eb079 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -19,6 +19,7 @@ #endif #include "types.comp" +#include "utils.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { }; uint _ne1; -shared uint _ne1_sh; +layout (constant_id = 5) const uint subgroup_size = 32; +shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { @@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem return elem; } +void load_row_ids(uint expert_idx, bool nei0_is_pow2) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); + } + _ne1 += total; + iter &= 15; + } + barrier(); +} #endif void main() { @@ -157,45 +217,12 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - // Spread the search across all elements in the first subgroup - if (gl_SubgroupID == 0) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_SubgroupInvocationID; - bool in_range = i < num_elements; - uint ii1 = i / p.nei0; - uint ii0 = i % p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - uint idx = subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx) { - row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); - } - _ne1 += subgroupBallotBitCount(ballot); - iter &= 15; - } - _ne1_sh = _ne1; + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true); + } else { + load_row_ids(expert_idx, false); } - barrier(); - - _ne1 = _ne1_sh; - // Workgroup has no work if (ic * BN >= _ne1) return; #endif From 0a9b43e507a359ca392c037cf341f55137ad0b69 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 23 Aug 2025 08:35:21 +0200 Subject: [PATCH 4/7] vulkan : support ggml_mean (#15393) * vulkan : support ggml_mean * vulkan : support sum, sum_rows and mean with non-contiguous tensors * vulkan : fix subbuffer size not accounting for misalign offset * tests : add backend-op tests for non-contiguous sum_rows * cuda : require contiguous src for SUM_ROWS, MEAN support * sycl : require contiguous src for SUM, SUM_ROWS, ARGSORT support * require ggml_contiguous_rows in supports_op and expect nb00=1 in the shader --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 82 +++++++++++++++++-- .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 43 ++++++++-- tests/test-backend-ops.cpp | 21 ++++- 5 files changed, 135 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d29a0b573..aa45ab39e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3485,11 +3485,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a0a650e92..12dd5dd2e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4391,10 +4391,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; - case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_POOL_2D: case GGML_OP_ACC: case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2c5678f48..007556cf4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1015,6 +1015,39 @@ struct vk_op_upscale_push_constants { float sf0; float sf1; float sf2; float sf3; }; +struct vk_op_sum_rows_push_constants +{ + uint32_t n_cols; + uint32_t ne01, ne02; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; + float weight; + uint32_t misalign_offsets; + uint32_t ne0_12mp, ne0_12L; + uint32_t ne0_1mp, ne0_1L; +}; + +vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { + uint32_t type_size = (uint32_t)ggml_type_size(src->type); + vk_op_sum_rows_push_constants p = {}; + p.n_cols = (uint32_t)n_cols; + p.ne01 = (uint32_t)src->ne[1]; + p.ne02 = (uint32_t)src->ne[2]; + p.nb01 = (uint32_t)src->nb[1] / type_size; + p.nb02 = (uint32_t)src->nb[2] / type_size; + p.nb03 = (uint32_t)src->nb[3] / type_size; + p.nb11 = (uint32_t)dst->nb[1] / type_size; + p.nb12 = (uint32_t)dst->nb[2] / type_size; + p.nb13 = (uint32_t)dst->nb[3] / type_size; + p.weight = 1.0f; + return p; +} + +template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { + init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); + init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); +} + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -3128,7 +3161,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); @@ -7249,6 +7282,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_sum_rows_f32; } @@ -7387,6 +7421,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_CONV_2D_DW: case GGML_OP_IM2COL: case GGML_OP_SET_ROWS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return true; default: return false; @@ -7421,6 +7458,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -7571,10 +7618,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); if (op_supports_incontiguous) { - x_sz = ggml_nbytes(src0); - y_sz = use_src1 ? ggml_nbytes(src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) : 0; - d_sz = ggml_nbytes(dst); + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); if (x_buf_offset + x_sz >= d_X->size) { x_sz = VK_WHOLE_SIZE; @@ -7602,6 +7649,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: { const uint32_t nr = ggml_nrows(src0); @@ -8588,11 +8636,19 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); } static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); +} + +static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun); } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -9815,6 +9871,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -9884,6 +9941,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -10087,6 +10145,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUM_ROWS: ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_MEAN: + ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_ARGMAX: ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); @@ -10246,6 +10308,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_ARGSORT: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -11483,8 +11546,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: @@ -12043,6 +12109,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_ARGMAX) { tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 961e5ffa1..759204afa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,9 +1,9 @@ #version 450 -#include "generic_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + + shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; + const float weight = p.weight; - tmp[col] = FLOAT_TYPE(0.0f); + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; - for (uint i = col; i < p.KX; i += BLOCK_SIZE) { - tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]); + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + tmp[col] = FLOAT_TYPE(0.0); + + for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[src_idx + i]); } barrier(); @@ -32,6 +65,6 @@ void main() { } if (col == 0) { - data_d[row] = D_TYPE(tmp[0]); + data_d[dst_idx] = D_TYPE(tmp[0] * weight); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a51527ca5..2e53f8e21 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4300,20 +4300,32 @@ struct test_sum : public test_case { struct test_sum_rows : public test_case { const ggml_type type; const std::array ne; + const bool permute; + const bool slice; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR4(type, ne, permute, slice); } test_sum_rows(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}) - : type(type), ne(ne) {} + std::array ne = {10, 5, 4, 3}, + bool permute = false, bool slice = false) + : type(type), ne(ne), permute(permute), slice(slice) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); + if (slice) { + a = ggml_view_4d(ctx, a, + ne[0], ne[1], ne[2] / 2, ne[3] - 1, + a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]); + } + if (permute) { + a = ggml_permute(ctx, a, 0, 2, 3, 1); + } + ggml_tensor * out = ggml_sum_rows(ctx, a); ggml_set_name(out, "out"); @@ -6195,6 +6207,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); From b55f06e1aa67fb10e89f53e31bbccf37eb2678ea Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Sat, 23 Aug 2025 14:58:57 +0800 Subject: [PATCH 5/7] vulkan.Dockerfile: install vulkan SDK using tarball (#15282) Signed-off-by: Xiaodong Ye --- .devops/vulkan.Dockerfile | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index fcd81ffa1..6cf87c67e 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -2,14 +2,30 @@ ARG UBUNTU_VERSION=24.04 FROM ubuntu:$UBUNTU_VERSION AS build -# Install build tools -RUN apt update && apt install -y git build-essential cmake wget +# Ref: https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html -# Install Vulkan SDK and cURL -RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ - wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \ - apt update -y && \ - apt-get install -y vulkan-sdk libcurl4-openssl-dev curl +# Install build tools +RUN apt update && apt install -y git build-essential cmake wget xz-utils + +# Install Vulkan SDK +ARG VULKAN_VERSION=1.4.321.1 +RUN ARCH=$(uname -m) && \ + wget -qO /tmp/vulkan-sdk.tar.xz https://sdk.lunarg.com/sdk/download/${VULKAN_VERSION}/linux/vulkan-sdk-linux-${ARCH}-${VULKAN_VERSION}.tar.xz && \ + mkdir -p /opt/vulkan && \ + tar -xf /tmp/vulkan-sdk.tar.xz -C /tmp --strip-components=1 && \ + mv /tmp/${ARCH}/* /opt/vulkan/ && \ + rm -rf /tmp/* + +# Install cURL and Vulkan SDK dependencies +RUN apt install -y libcurl4-openssl-dev curl \ + libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev + +# Set environment variables +ENV VULKAN_SDK=/opt/vulkan +ENV PATH=$VULKAN_SDK/bin:$PATH +ENV LD_LIBRARY_PATH=$VULKAN_SDK/lib:$LD_LIBRARY_PATH +ENV CMAKE_PREFIX_PATH=$VULKAN_SDK:$CMAKE_PREFIX_PATH +ENV PKG_CONFIG_PATH=$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH # Build it WORKDIR /app From 289bf4113ef5c02d8f5eb0cf2d86683d8b8bc4d9 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 Aug 2025 02:33:36 -0500 Subject: [PATCH 6/7] vulkan: Rewrite synchronization to allow some overlap between nodes (#15489) Track a list of nodes that need synchronization, and only sync if the new node depends on them (or overwrites them). This allows some overlap which can improve performance, and centralizes a big chunk of the synchronization logic. The remaining synchronization logic involves writes to memory other than the nodes, e.g. for dequantization or split_k. Each of these allocations has a bool indicating whether they were in use and need to be synced. This should be checked before they are written to, and set to true after they are done being consumed. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 233 ++++++++++++++++++++++----- 1 file changed, 193 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 007556cf4..c7cfb6473 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1231,6 +1231,14 @@ struct ggml_backend_vk_context { vk_pipeline_struct * prealloc_y_last_pipeline_used {}; const ggml_tensor * prealloc_y_last_tensor_used {}; + // Track which nodes have been used since the last sync, and whether they were written to + std::vector unsynced_nodes_written; + std::vector unsynced_nodes_read; + // Track which prealloc buffers have pending reads that need to be synchronized. + // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set), + // and set to true after the buffer contents are consumed. + bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; + vk_buffer buffer_pool[MAX_VK_BUFFERS]; vk_context_ref compute_ctx; @@ -1906,14 +1914,18 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { return { buf, 0, VK_WHOLE_SIZE }; } -static void ggml_vk_sync_buffers(vk_context& ctx) { +static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->p->q->transfer_only; + const bool transfer_queue = subctx->p->q->transfer_only; - ctx->s->buffer.pipelineBarrier( - ctx->p->q->stage_flags, - ctx->p->q->stage_flags, + if (ctx) { + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + } + + subctx->s->buffer.pipelineBarrier( + subctx->p->q->stage_flags, + subctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -4898,7 +4910,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4913,7 +4925,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); VkBufferCopy buf_copy{ 0, offset, copy_size }; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { @@ -4967,7 +4979,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } } - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -4988,7 +5000,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz offset, copy_size}; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); if (width == spitch) { @@ -5068,7 +5080,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size if (buf != nullptr) { // Memory is pinned, use as staging buffer - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); return; @@ -5085,7 +5097,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size vk_buffer& staging_buffer = src->device->sync_staging; - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); @@ -5275,13 +5287,16 @@ static void ggml_vk_matmul( uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); - ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); return; } + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + GGML_ASSERT(batch_stride_d == m * n); // Round the split size up to a multiple of 256 (k-quant alignment) @@ -5291,9 +5306,10 @@ static void ggml_vk_matmul( const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); + ctx->prealloc_split_k_need_sync = true; } static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { @@ -5338,7 +5354,6 @@ static void ggml_vk_matmul_id( "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); - ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); @@ -5469,8 +5484,8 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; init_pushconst_fastdiv(pc); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); } static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { @@ -5488,8 +5503,8 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + ggml_vk_sync_buffers(ctx, subctx); } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5684,12 +5699,23 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (y_non_contig || quantize_y) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || @@ -5728,6 +5754,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5874,6 +5907,17 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (y_non_contig) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); @@ -5917,10 +5961,16 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x, stride_batch_y, stride_batch_d, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -6007,7 +6057,6 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c workgroups_z /= gqa_ratio; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); } @@ -6094,7 +6143,6 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } @@ -6306,13 +6354,24 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (y_non_contig) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || @@ -6343,6 +6402,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { @@ -6502,6 +6568,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte GGML_ASSERT(qy_sz == y_sz); } + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (y_non_contig) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); @@ -6538,11 +6615,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), (uint32_t)nei0, (uint32_t)ne11, }; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, pc, { groups_x, (uint32_t)nei0, groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -6925,9 +7008,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; - ggml_vk_sync_buffers(subctx); - if (split_k > 1) { + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, @@ -6943,7 +7028,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // cancel out the divide by wg_denoms[0]. pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { @@ -6952,6 +7037,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -7820,7 +7906,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_y = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_SOFT_MAX) { // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer @@ -7838,7 +7923,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer @@ -7849,30 +7933,23 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co subbuf_z = { d_X, 0, x_sz }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_IM2COL) { // im2col uses only src1 and dst buffers - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { - ggml_vk_sync_buffers(subctx); // count_equal assumes that destination buffer is initialized with zeroes ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); - ggml_vk_sync_buffers(subctx); + ggml_vk_sync_buffers(ctx, subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_OPT_STEP_SGD) { // OPT_STEP_SGD works on src0, it does not need dst - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); } else if (use_src2) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else { - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } } @@ -7999,7 +8076,6 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, elements = { ne, 1, 1 }; } - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, @@ -8112,8 +8188,6 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; } - ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; @@ -8251,8 +8325,6 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; - ggml_vk_sync_buffers(subctx); - vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; @@ -9964,6 +10036,83 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } + if (!dryrun) { + // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers + // to synchronize them. This handles most "normal" synchronization when computing the graph, and when + // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers + // outside of this logic. When a node uses one of the prealloc buffers for something like + // dequantization or split_k, additional synchronization is needed between those passes. + bool need_sync = false; + + // Check whether "node" requires synchronization. The node requires synchronization if it + // overlaps in memory with another unsynchronized node and at least one of them is a write. + // Destination nodes are checked against both the written/read lists. Source nodes are only + // checked against the written list. Two nodes overlap in memory if they come from the same + // buffer and the tensor or view ranges overlap. + auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector &unsynced_nodes) -> bool { + if (unsynced_nodes.size() == 0) { + return false; + } + auto n_base = vk_tensor_offset(node) + node->view_offs; + auto n_size = ggml_nbytes(node); + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + for (auto &other : unsynced_nodes) { + ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; + vk_buffer o_buf = o_buf_ctx->dev_buffer; + if (a_buf == o_buf) { + auto o_base = vk_tensor_offset(other) + other->view_offs; + auto o_size = ggml_nbytes(other); + + if ((o_base <= n_base && n_base < o_base + o_size) || + (n_base <= o_base && o_base < n_base + n_size)) { + return true; + } + } + } + return false; + }; + + // For all fused ops, check if the destination node or any of the source + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + } + } + if (need_sync) { + VK_LOG_DEBUG("node_idx=" << i << " sync"); + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); + } else { + VK_LOG_DEBUG("node_idx=" << i << " unsynced"); + } + // Add the last fused node and all fused source nodes to the unsynchronized list. + const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ctx->unsynced_nodes_written.push_back(last_node); + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + ctx->unsynced_nodes_read.push_back(cur_node->src[j]); + } + } + } + switch (node->op) { case GGML_OP_REPEAT: ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); @@ -10427,6 +10576,10 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->gc.temp_buffers.clear(); ctx->prealloc_y_last_pipeline_used = {}; + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); From 21dc4ddaf21b8ed551d717e7606abd2cffbacdbf Mon Sep 17 00:00:00 2001 From: LaffeyNyaa <112215776+LaffeyNyaa@users.noreply.github.com> Date: Sat, 23 Aug 2025 16:38:30 +0800 Subject: [PATCH 7/7] chat : fix debug build assertion in trim function (#15520) --- src/llama-chat.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0a96a9a57..4d6fdf822 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -16,10 +16,10 @@ static std::string trim(const std::string & str) { size_t start = 0; size_t end = str.size(); - while (start < end && isspace(str[start])) { + while (start < end && isspace(static_cast(str[start]))) { start += 1; } - while (end > start && isspace(str[end - 1])) { + while (end > start && isspace(static_cast(str[end - 1]))) { end -= 1; } return str.substr(start, end - start);