vulkan: use GL_NV_cooperative_matrix_decode_vector for faster matmul (#23541)

This commit is contained in:
Jeff Bolz 2026-05-27 10:18:28 -05:00 committed by GitHub
parent 837bb6b447
commit b36eefc1b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 865 additions and 6 deletions

View file

@ -79,6 +79,12 @@ if (Vulkan_FOUND)
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix_decode_vector"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp"
"GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"

View file

@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#include <vulkan/vulkan.hpp>
// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the
// installed Vulkan headers predate the extension.
#ifndef VK_NV_cooperative_matrix_decode_vector
#define VK_NV_cooperative_matrix_decode_vector 1
#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector"
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000)
typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
VkStructureType sType;
void* pNext;
VkBool32 cooperativeMatrixDecodeVector;
} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV;
#endif
// SPIR-V Headers: different SDK installations expose different include paths.
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
@ -678,6 +691,7 @@ struct vk_device_struct {
uint32_t coopmat_int_k;
bool coopmat2;
bool coopmat2_decode_vector;
bool pipeline_executable_properties_support {};
@ -2167,6 +2181,136 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it
// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the
// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the
// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction.
// Returns true when the input used the extension (and `out` was populated with a
// stripped copy); returns false otherwise without touching `out`.
static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector";
if (word_count < 5) {
return false;
}
bool uses_decode_vector = false;
for (size_t pos = 5; pos < word_count; ) {
uint32_t word = code[pos];
uint32_t wc = word >> spv::WordCountShift;
uint32_t op = word & spv::OpCodeMask;
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
if (op == spv::OpExtension && wc >= 2) {
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
if (strcmp(s, kDecodeVectorExt) == 0) {
uses_decode_vector = true;
break;
}
}
pos += wc;
}
if (!uses_decode_vector) {
return false;
}
VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector");
// Bulk-copy unchanged runs and only break the run when an instruction needs to
// be dropped or patched. Use reserve + insert/push_back so the destination buffer
// is touched exactly once (no zero-initialization pass from resize()).
out.clear();
out.reserve(word_count);
size_t run_start = 0;
auto flush_run = [&](size_t up_to) {
if (up_to > run_start) {
out.insert(out.end(), code + run_start, code + up_to);
}
};
for (size_t pos = 5; pos < word_count; ) {
uint32_t word = code[pos];
uint32_t wc = word >> spv::WordCountShift;
uint32_t op = word & spv::OpCodeMask;
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
if (op == spv::OpExtension && wc >= 2) {
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
if (strcmp(s, kDecodeVectorExt) == 0) {
flush_run(pos);
pos += wc;
run_start = pos;
continue;
}
}
if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) {
flush_run(pos);
pos += wc;
run_start = pos;
continue;
}
if (op == kSpvOpCooperativeMatrixLoadTensorNV) {
// [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...]
GGML_ASSERT(wc >= 8);
uint32_t mem_mask = code[pos + 6];
size_t cur = pos + 7;
// Each of these MemoryAccess bits (when set) carries one trailing operand.
cur += (mem_mask & 0x2) ? 1 : 0; // Aligned
cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable
cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible
cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask
cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask
GGML_ASSERT(cur < pos + wc);
uint32_t ta_mask = code[cur];
if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) {
pos += wc;
continue; // leave instruction inside the current unchanged run
}
flush_run(pos);
// Append unchanged prefix of the instruction (header through the mem-extras).
size_t inst_start = out.size();
size_t pre_n = cur - pos;
out.insert(out.end(), code + pos, code + pos + pre_n);
// Emit TA mask with the DecodeVectorFunc bit cleared.
out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit);
// TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim;
// DecodeVectorFunc (0x4) is dropped along with its trailing id operand.
size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0);
if (keep_ta_extras) {
out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras);
}
GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1);
// Patch the instruction header with the new (one-shorter) word count.
uint32_t new_wc = wc - 1;
out[inst_start] = (new_wc << spv::WordCountShift) | op;
pos += wc;
run_start = pos;
continue;
}
pos += wc;
}
flush_run(word_count);
return true;
}
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
@ -2238,6 +2382,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
}
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
if (device->coopmat2 && !device->coopmat2_decode_vector) {
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
std::vector<uint32_t> stripped;
if (ggml_vk_strip_decode_vector(src, src_n, stripped)) {
spirv = std::move(stripped);
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
}
}
#endif
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
vk::PushConstantRange pcr(
@ -5159,6 +5315,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool amd_shader_core_properties2 = false;
bool pipeline_robustness = false;
bool coopmat2_support = false;
bool coopmat2_decode_vector_support = false;
bool pipeline_executable_properties_support = false;
device->coopmat_support = false;
device->integer_dot_product = false;
@ -5193,6 +5350,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#endif
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
coopmat2_decode_vector_support = true;
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
@ -5470,6 +5630,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
#endif
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
if (coopmat2_decode_vector_support) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME);
}
#if defined(VK_KHR_shader_bfloat16)
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
bfloat16_features.pNext = nullptr;
@ -5629,6 +5797,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
found_fp32_128 && found_fp32_256 &&
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
device->coopmat2 = true;
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
}
}
#endif
@ -5915,6 +6084,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool fp16_compute = false;
bool coopmat_support = false;
bool coopmat2_support = false;
bool coopmat2_decode_vector_support = false;
bool integer_dot_product = false;
bool bfloat16_support = false;
@ -5933,6 +6103,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#endif
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
coopmat2_decode_vector_support = true;
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
@ -6017,6 +6190,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
}
#endif
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
if (coopmat2_decode_vector_support) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
}
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
@ -6041,7 +6221,14 @@ static void ggml_vk_print_gpu_info(size_t idx) {
#endif
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
coopmat2_decode_vector_support = false;
#endif
std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2")
: coopmat_support ? "KHR_coopmat"
: "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",

View file

@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 glslc support")
endif()
if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 decode_vector glslc support")
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
message(STATUS "Enabling dot glslc support")

View file

@ -1,4 +1,12 @@
// Each format defines a scalar dequantFunc<T> plus a V=4 dequantFunc<T>_v
// passed as the optional vector decoder to coopMatLoadTensorNV via
// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support
// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V.
#ifdef GL_NV_cooperative_matrix_decode_vector
#extension GL_NV_cooperative_matrix_decode_vector : enable
#endif
#include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2
return bit != 0u ? d : -d;
}
f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t md = -d;
const uint idx = coordInBlock[1];
const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u);
return f16vec4(
(qs_nib & 1u) != 0u ? d : md,
(qs_nib & 2u) != 0u ? d : md,
(qs_nib & 4u) != 0u ? d : md,
(qs_nib & 8u) != 0u ? d : md);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2; // 0 or 4
const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6}
const uint qsw = uint32_t(bl.block.qs[qs_i ])
| (uint32_t(bl.block.qs[qs_i + 1u]) << 16);
// shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte.
const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu;
const u8vec4 q = unpack8(q4);
return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d)));
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 {
block_q4_1_packed32 block;
};
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl);
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2; // 0 or 4
const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4)
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m)));
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
block_q5_0 block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 {
block_q5_0_packed16 block;
};
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2; // 0 or 4
const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6}
const uint qsw = uint32_t(bl16.block.qs[qs_i ])
| (uint32_t(bl16.block.qs[qs_i + 1u]) << 16);
const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]);
const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits
const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u;
return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d)));
}
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
block_q5_1 block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 {
block_q5_1_packed32 block;
};
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl);
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2; // 0 or 4
const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4)
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits
const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u;
return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m)));
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};
@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint base = idx >> 1u;
const uint w = uint(uint16_t(bl.block.qs[base]))
| (uint(uint16_t(bl.block.qs[base + 1u])) << 16u);
const i8vec4 qi = unpack8(int32_t(w));
return f16vec4(vec4(qi) * vec4(float(d)));
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
block_q2_K block;
};
@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
block_q2_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 {
block_q2_K_packed32 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl);
const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = idx >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
// qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0,
// so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2).
const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2);
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
const uint qs4 = (qsw >> qsshift) & 0x03030303u;
const u8vec4 qi = unpack8(qs4);
const uint scales = bl.block.scales[scalesi];
const float16_t d_sub = dm.x * float16_t(scales & 0xF);
const float16_t m_sub = dm.y * float16_t(scales >> 4);
return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub)));
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
block_q3_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 {
block_q3_K_packed16 block;
};
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint n = idx >> 7; // 0,1
const uint is = idx >> 4; // 0..15
const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3
const uint qsshift = halfsplit << 1; // 0,2,4,6
const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte)
uint32_t scaleidx0 = (is < 8) ? is : (is - 8);
uint32_t scaleidx0shift = (is < 8) ? 0u : 4u;
uint32_t scaleidx1 = is + 8 - (is / 4) * 4;
uint32_t scaleidx1shift = (is / 4) * 2;
const int8_t us = int8_t(
((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) |
(((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
const float16_t dl = bl.block.d * float16_t(int(us) - 32);
// For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4.
const uint qsi = (n << 5) + (idx & 0x1Cu);
const uint hmi = (idx & 0x1Cu);
// Two adjacent uint16 packed16 reads, combined into a uint32 in registers.
// After this: byte j of qsw / hmw holds the data for element idx+j.
const uint qsw = uint32_t(bl16.block.qs[qsi >> 1])
| (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16);
const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1])
| (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16);
// qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits
// with no inter-byte leakage.
const uint ql4 = (qsw >> qsshift) & 0x03030303u;
const uint qh4 = (hmw >> hbit) & 0x01010101u;
const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4);
return f16vec4(vec4(q) * vec4(float(dl)));
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
block_q4_K block;
};
@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 {
block_q4_K_packed32 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
return float16_t(ret);
}
f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint is = idx >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
// idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index:
// (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a
// single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage.
const uint sh = (idx & 0x20u) >> 3u;
const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]);
const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu);
return f16vec4(vec4(d) * vec4(q) - vec4(m));
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
block_q5_K block;
};
@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
block_q5_K_packed128 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 {
block_q5_K_packed32 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
return float16_t(ret);
}
f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint is = idx >> 5;
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
// sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage).
const uint sh = (idx & 0x20u) >> 3u;
const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2);
const uint qh_w = (idx & 0x1Eu) >> 2;
const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu;
// qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4.
const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u;
const u8vec4 qi = unpack8(ql4 | qh4);
return f16vec4(vec4(qi) * vec4(d) - vec4(m));
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
block_q6_K block;
};
@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
return ret;
}
f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x40) >> 6;
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint is = idx >> 4;
const uint sh = b * 4; // 0 or 4
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1);
const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1);
// Two adjacent uint16 packed16 reads, combined into a uint32 in registers.
// After this: byte j of qlw / qhw holds the data for element idx+j.
const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16);
const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16);
// sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the
// wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position.
const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu;
const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u;
const ivec4 qi = ivec4(unpack8(ql4 | qh4));
return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale)));
}
#if defined(DATA_A_IQ1_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
block_iq1_s block;
@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
return ret;
}
f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = idx >> 5;
const uint ib8 = idx >> 3;
const int i8b = int(idx & 4); // 0 or 4
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
const ivec4 q = ivec4(
bitfieldExtract(int(grid), 2 * (i8b + 0), 2),
bitfieldExtract(int(grid), 2 * (i8b + 1), 2),
bitfieldExtract(int(grid), 2 * (i8b + 2), 2),
bitfieldExtract(int(grid), 2 * (i8b + 3), 2));
return f16vec4((vec4(q) + vec4(delta)) * dl);
}
#endif
#if defined(DATA_A_IQ1_M)
@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
return ret;
}
f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = idx >> 3;
const uint ib16 = idx >> 4;
const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0;
const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)];
const ivec4 q = ivec4(
bitfieldExtract(int(grid), 2 * (i8b + 0), 2),
bitfieldExtract(int(grid), 2 * (i8b + 1), 2),
bitfieldExtract(int(grid), 2 * (i8b + 2), 2),
bitfieldExtract(int(grid), 2 * (i8b + 3), 2));
return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl));
}
#endif
#if defined(DATA_A_IQ2_XXS)
@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
const uint idx = coordInBlock[1];
const uint ib32 = idx >> 5;
const uint ib8 = (idx & 0x18) >> 3;
const uint iqs = 8 * ib32 + ib8;
const uint qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
const uint sb = sign >> (idx & 7u);
const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
const u8vec4 g = unpack8(g2);
return f16vec4(
dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
}
#endif
#if defined(DATA_A_IQ2_XS)
@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint is = idx >> 5;
const uint sshift = (idx & 0x10) >> 2;
const uint iqs = idx >> 3;
const uint16_t qs = bl.block.qs[iqs];
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
const uint sb = sign >> (idx & 7u);
const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
const u8vec4 g = unpack8(g2);
return f16vec4(
dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
}
#endif
#if defined(DATA_A_IQ2_S)
@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
return float16_t(v[idx & 1]);
}
f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint ib32 = idx >> 5;
const uint ib8 = idx >> 3;
const uint qhshift = 2 * (ib8 % 4);
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u);
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
const u8vec4 g = unpack8(g2);
return f16vec4(
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
}
#endif
#if defined(DATA_A_IQ3_XXS)
@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
const uint idx = coordInBlock[1];
const uint iqs = idx >> 2;
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1]));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u);
const uint grid = iq3xxs_grid[qs];
const u8vec4 g = unpack8(grid);
return f16vec4(
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
}
#endif
#if defined(DATA_A_IQ3_S)
@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
return float16_t(v[idx & 1]);
}
f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx >> 2;
const uint iqh = idx >> 5;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u);
const uint scale = bl.block.scales[iqs / 16];
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const u8vec4 g = unpack8(grid);
return f16vec4(
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
}
#endif
#if defined(DATA_A_IQ4_XS)
@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
block_iq4_xs block;
};
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 {
block_iq4_xs_packed32 block;
};
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
return ret;
}
f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = idx >> 5; // 0..7
const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF;
const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3;
const uint qshift = (idx & 0x10) >> 2; // {0, 4}
const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32)
const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32);
const uint qsw = bl32.block.qs[qs_w];
const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu);
const vec4 ret = vec4(
float(kvalues_iq4nl[qv.x]),
float(kvalues_iq4nl[qv.y]),
float(kvalues_iq4nl[qv.z]),
float(kvalues_iq4nl[qv.w])) * float(dl);
return f16vec4(ret);
}
#endif
#if defined(DATA_A_IQ4_NL)
@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
block_iq4_nl block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 {
block_iq4_nl_packed16 block;
};
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
}
f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2; // 0 or 4
const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6}
const uint qsw = uint32_t(bl16.block.qs[qs_i ])
| (uint32_t(bl16.block.qs[qs_i + 1u]) << 16);
// shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte.
const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
return f16vec4(
float(d) * float(kvalues_iq4nl[q.x]),
float(d) * float(kvalues_iq4nl[q.y]),
float(d) * float(kvalues_iq4nl[q.z]),
float(d) * float(kvalues_iq4nl[q.w]));
}
#endif
#if defined(DATA_A_MXFP4)
@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uvec4 qv = uvec4(
uint(bl.block.qs[iqs]),
uint(bl.block.qs[iqs + 1u]),
uint(bl.block.qs[iqs + 2u]),
uint(bl.block.qs[iqs + 3u]));
qv = (qv >> shift) & 0xFu;
const vec4 ret = vec4(
float(kvalues_mxfp4[qv.x]),
float(kvalues_mxfp4[qv.y]),
float(kvalues_mxfp4[qv.z]),
float(kvalues_mxfp4[qv.w])) * d * 0.5f;
return f16vec4(ret);
}
#endif
#if defined(DATA_A_NVFP4)
@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF
block_nvfp4 block;
};
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 {
block_nvfp4_packed32 block;
};
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords
qs = (qs >> shift) & 0xF;
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
}
f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl);
const uint idx = coordInBlock[1];
const uint sub = idx >> 4;
const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8)
const uint shift = (idx & 0x8) >> 1;
const float d = ue4m3_to_fp32(bl.block.d[sub]);
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
const vec4 ret = vec4(
float(kvalues_mxfp4[qv.x]),
float(kvalues_mxfp4[qv.y]),
float(kvalues_mxfp4[qv.z]),
float(kvalues_mxfp4[qv.w])) * d * 0.5f;
return f16vec4(ret);
}
#endif
#if defined(DATA_A_Q1_0)
#define dequantFuncA dequantFuncQ1_0
#define dequantFuncA_v dequantFuncQ1_0_v
#elif defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#define dequantFuncA_v dequantFuncQ4_0_v
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#define dequantFuncA_v dequantFuncQ4_1_v
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#define dequantFuncA_v dequantFuncQ5_0_v
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#define dequantFuncA_v dequantFuncQ5_1_v
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#define dequantFuncA_v dequantFuncQ8_0_v
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#define dequantFuncA_v dequantFuncQ2_K_v
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#define dequantFuncA_v dequantFuncQ3_K_v
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define dequantFuncA_v dequantFuncQ4_K_v
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define dequantFuncA_v dequantFuncQ5_K_v
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#define dequantFuncA_v dequantFuncQ6_K_v
#elif defined(DATA_A_IQ1_S)
#define dequantFuncA dequantFuncIQ1_S
#define dequantFuncA_v dequantFuncIQ1_S_v
#elif defined(DATA_A_IQ1_M)
#define dequantFuncA dequantFuncIQ1_M
#define dequantFuncA_v dequantFuncIQ1_M_v
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#define dequantFuncA_v dequantFuncIQ2_XXS_v
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#define dequantFuncA_v dequantFuncIQ2_XS_v
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#define dequantFuncA_v dequantFuncIQ2_S_v
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#define dequantFuncA_v dequantFuncIQ3_XXS_v
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#define dequantFuncA_v dequantFuncIQ3_S_v
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#define dequantFuncA_v dequantFuncIQ4_XS_v
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#define dequantFuncA_v dequantFuncIQ4_NL_v
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#define dequantFuncA_v dequantFuncMXFP4_v
#elif defined(DATA_A_NVFP4)
#define dequantFuncA dequantFuncNVFP4
#define dequantFuncA_v dequantFuncNVFP4_v
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif

View file

@ -0,0 +1,7 @@
#version 460
#extension GL_NV_cooperative_matrix_decode_vector : require
void main()
{
}

View file

@ -11,6 +11,9 @@
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#ifdef GL_NV_cooperative_matrix_decode_vector
#extension GL_NV_cooperative_matrix_decode_vector : enable
#endif
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
@ -54,6 +57,41 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const
}
}
// V=4 vector decode for K/V; dispatches to per-format _v decoders.
f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeK) {
case 0u: return f16vec4(decodeBufF32(bl_in).block);
case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return f16vec4(0);
}
}
f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
switch (FaTypeV) {
case 0u: return f16vec4(decodeBufF32(bl_in).block);
case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
default: return f16vec4(0);
}
}
#ifdef GL_NV_cooperative_matrix_decode_vector
#define FADECODEK , faDecodeK, faDecodeKVector
#define FADECODEV , faDecodeV, faDecodeVVector
#else
#define FADECODEK , faDecodeK
#define FADECODEV , faDecodeV
#endif
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
@ -259,7 +297,7 @@ void main() {
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
const bool k_use_decode = (bs_k > 1u);
if (k_use_decode) {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK);
} else {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
}
@ -325,7 +363,7 @@ void main() {
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
const bool v_use_decode = (bs_v > 1u);
if (v_use_decode) {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV);
} else {
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
}

View file

@ -71,10 +71,12 @@ layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA
#include "dequant_funcs_cm2.glsl"
#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector)
#define DECODEFUNCA , dequantFuncA, dequantFuncA_v
#else
#define DECODEFUNCA , dequantFuncA
#endif
#else
#define DECODEFUNCA
#endif

View file

@ -1722,11 +1722,18 @@ struct block_nvfp4
uint8_t qs[QUANT_K_NVFP4 / 2];
};
struct block_nvfp4_packed32
{
uint32_t d[QUANT_K_NVFP4 / 16 / 4];
uint32_t qs[QUANT_K_NVFP4 / 2 / 4];
};
#if defined(DATA_A_NVFP4)
#define QUANT_K QUANT_K_NVFP4
#define QUANT_R QUANT_R_NVFP4
#define QUANT_AUXF 1
#define A_TYPE block_nvfp4
#define A_TYPE_PACKED32 block_nvfp4_packed32
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)