mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-30 20:33:39 +00:00
vulkan: use GL_NV_cooperative_matrix_decode_vector for faster matmul (#23541)
This commit is contained in:
parent
837bb6b447
commit
b36eefc1b3
8 changed files with 865 additions and 6 deletions
|
|
@ -79,6 +79,12 @@ if (Vulkan_FOUND)
|
|||
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_NV_cooperative_matrix_decode_vector"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp"
|
||||
"GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_integer_dot_product"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the
|
||||
// installed Vulkan headers predate the extension.
|
||||
#ifndef VK_NV_cooperative_matrix_decode_vector
|
||||
#define VK_NV_cooperative_matrix_decode_vector 1
|
||||
#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector"
|
||||
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000)
|
||||
typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
|
||||
VkStructureType sType;
|
||||
void* pNext;
|
||||
VkBool32 cooperativeMatrixDecodeVector;
|
||||
} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV;
|
||||
#endif
|
||||
|
||||
// SPIR-V Headers: different SDK installations expose different include paths.
|
||||
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
|
||||
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
|
||||
|
|
@ -678,6 +691,7 @@ struct vk_device_struct {
|
|||
uint32_t coopmat_int_k;
|
||||
|
||||
bool coopmat2;
|
||||
bool coopmat2_decode_vector;
|
||||
|
||||
bool pipeline_executable_properties_support {};
|
||||
|
||||
|
|
@ -2167,6 +2181,136 @@ static uint32_t compile_count = 0;
|
|||
static std::mutex compile_count_mutex;
|
||||
static std::condition_variable compile_count_cond;
|
||||
|
||||
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
|
||||
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
|
||||
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
|
||||
|
||||
// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it
|
||||
// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the
|
||||
// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the
|
||||
// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction.
|
||||
// Returns true when the input used the extension (and `out` was populated with a
|
||||
// stripped copy); returns false otherwise without touching `out`.
|
||||
static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
|
||||
static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector";
|
||||
|
||||
if (word_count < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool uses_decode_vector = false;
|
||||
for (size_t pos = 5; pos < word_count; ) {
|
||||
uint32_t word = code[pos];
|
||||
uint32_t wc = word >> spv::WordCountShift;
|
||||
uint32_t op = word & spv::OpCodeMask;
|
||||
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
||||
if (op == spv::OpExtension && wc >= 2) {
|
||||
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
||||
if (strcmp(s, kDecodeVectorExt) == 0) {
|
||||
uses_decode_vector = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
pos += wc;
|
||||
}
|
||||
|
||||
if (!uses_decode_vector) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector");
|
||||
|
||||
// Bulk-copy unchanged runs and only break the run when an instruction needs to
|
||||
// be dropped or patched. Use reserve + insert/push_back so the destination buffer
|
||||
// is touched exactly once (no zero-initialization pass from resize()).
|
||||
out.clear();
|
||||
out.reserve(word_count);
|
||||
|
||||
size_t run_start = 0;
|
||||
auto flush_run = [&](size_t up_to) {
|
||||
if (up_to > run_start) {
|
||||
out.insert(out.end(), code + run_start, code + up_to);
|
||||
}
|
||||
};
|
||||
|
||||
for (size_t pos = 5; pos < word_count; ) {
|
||||
uint32_t word = code[pos];
|
||||
uint32_t wc = word >> spv::WordCountShift;
|
||||
uint32_t op = word & spv::OpCodeMask;
|
||||
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
||||
|
||||
if (op == spv::OpExtension && wc >= 2) {
|
||||
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
||||
if (strcmp(s, kDecodeVectorExt) == 0) {
|
||||
flush_run(pos);
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) {
|
||||
flush_run(pos);
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op == kSpvOpCooperativeMatrixLoadTensorNV) {
|
||||
// [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...]
|
||||
GGML_ASSERT(wc >= 8);
|
||||
|
||||
uint32_t mem_mask = code[pos + 6];
|
||||
size_t cur = pos + 7;
|
||||
// Each of these MemoryAccess bits (when set) carries one trailing operand.
|
||||
cur += (mem_mask & 0x2) ? 1 : 0; // Aligned
|
||||
cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable
|
||||
cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible
|
||||
cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask
|
||||
cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask
|
||||
GGML_ASSERT(cur < pos + wc);
|
||||
|
||||
uint32_t ta_mask = code[cur];
|
||||
if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) {
|
||||
pos += wc;
|
||||
continue; // leave instruction inside the current unchanged run
|
||||
}
|
||||
|
||||
flush_run(pos);
|
||||
|
||||
// Append unchanged prefix of the instruction (header through the mem-extras).
|
||||
size_t inst_start = out.size();
|
||||
size_t pre_n = cur - pos;
|
||||
out.insert(out.end(), code + pos, code + pos + pre_n);
|
||||
|
||||
// Emit TA mask with the DecodeVectorFunc bit cleared.
|
||||
out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit);
|
||||
|
||||
// TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim;
|
||||
// DecodeVectorFunc (0x4) is dropped along with its trailing id operand.
|
||||
size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0);
|
||||
if (keep_ta_extras) {
|
||||
out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras);
|
||||
}
|
||||
|
||||
GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1);
|
||||
|
||||
// Patch the instruction header with the new (one-shorter) word count.
|
||||
uint32_t new_wc = wc - 1;
|
||||
out[inst_start] = (new_wc << spv::WordCountShift) | op;
|
||||
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
pos += wc;
|
||||
}
|
||||
|
||||
flush_run(word_count);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
||||
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
||||
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
|
||||
|
|
@ -2238,6 +2382,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
|||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
if (device->coopmat2 && !device->coopmat2_decode_vector) {
|
||||
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
|
||||
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
|
||||
std::vector<uint32_t> stripped;
|
||||
if (ggml_vk_strip_decode_vector(src, src_n, stripped)) {
|
||||
spirv = std::move(stripped);
|
||||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
||||
|
||||
vk::PushConstantRange pcr(
|
||||
|
|
@ -5159,6 +5315,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
bool amd_shader_core_properties2 = false;
|
||||
bool pipeline_robustness = false;
|
||||
bool coopmat2_support = false;
|
||||
bool coopmat2_decode_vector_support = false;
|
||||
bool pipeline_executable_properties_support = false;
|
||||
device->coopmat_support = false;
|
||||
device->integer_dot_product = false;
|
||||
|
|
@ -5193,6 +5350,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
||||
coopmat2_decode_vector_support = true;
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
|
|
@ -5470,6 +5630,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
||||
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
||||
if (coopmat2_decode_vector_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME);
|
||||
}
|
||||
|
||||
#if defined(VK_KHR_shader_bfloat16)
|
||||
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
||||
bfloat16_features.pNext = nullptr;
|
||||
|
|
@ -5629,6 +5797,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
found_fp32_128 && found_fp32_256 &&
|
||||
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
||||
device->coopmat2 = true;
|
||||
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
@ -5915,6 +6084,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
bool fp16_compute = false;
|
||||
bool coopmat_support = false;
|
||||
bool coopmat2_support = false;
|
||||
bool coopmat2_decode_vector_support = false;
|
||||
bool integer_dot_product = false;
|
||||
bool bfloat16_support = false;
|
||||
|
||||
|
|
@ -5933,6 +6103,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
||||
coopmat2_decode_vector_support = true;
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
|
|
@ -6017,6 +6190,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
||||
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
||||
if (coopmat2_decode_vector_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
||||
|
||||
fp16 = fp16 && vk12_features.shaderFloat16;
|
||||
|
|
@ -6041,7 +6221,14 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|||
#endif
|
||||
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
|
||||
|
||||
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
||||
coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
||||
#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
coopmat2_decode_vector_support = false;
|
||||
#endif
|
||||
|
||||
std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2")
|
||||
: coopmat_support ? "KHR_coopmat"
|
||||
: "none";
|
||||
|
||||
std::string device_name = props2.properties.deviceName.data();
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 decode_vector glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling dot glslc support")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,12 @@
|
|||
|
||||
// Each format defines a scalar dequantFunc<T> plus a V=4 dequantFunc<T>_v
|
||||
// passed as the optional vector decoder to coopMatLoadTensorNV via
|
||||
// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support
|
||||
// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V.
|
||||
#ifdef GL_NV_cooperative_matrix_decode_vector
|
||||
#extension GL_NV_cooperative_matrix_decode_vector : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
|
||||
|
|
@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2
|
|||
return bit != 0u ? d : -d;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q1_0. Reads four adjacent sign bits starting at element
// coordInBlock[1] and maps each one to +d (bit set) or -d (bit clear).
// NOTE(review): the nibble selection below presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t pos_val = bl.block.d;
    const float16_t neg_val = -pos_val;

    // Byte elem/8 of qs, shifted so that bit 0 belongs to the first requested element.
    const uint packed_byte = uint(bl.block.qs[elem >> 3]);
    const uint sign_bits = packed_byte >> (elem & 0x4u);

    f16vec4 result;
    result.x = ((sign_bits & 1u) != 0u) ? pos_val : neg_val;
    result.y = ((sign_bits & 2u) != 0u) ? pos_val : neg_val;
    result.z = ((sign_bits & 4u) != 0u) ? pos_val : neg_val;
    result.w = ((sign_bits & 8u) != 0u) ? pos_val : neg_val;
    return result;
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
|
|
@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q4_0: returns (q - 8) * d for four quants starting at element
// coordInBlock[1].
// NOTE(review): the word indexing presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Bit 4 of the element index picks the nibble: shift is 0 or 4.
    const uint nib_shift = (elem & 0x10) >> 2;
    // First of two adjacent 16-bit words; even, in {0,2,4,6}.
    const uint word16 = (elem & 0xE) >> 1;

    const uint lo_half = uint32_t(bl.block.qs[word16]);
    const uint hi_half = uint32_t(bl.block.qs[word16 + 1u]) << 16;
    const uint packed = lo_half | hi_half;

    // With a shift of 0 or 4 the per-byte 0x0F mask isolates the wanted nibble in each byte.
    const uint nibbles = (packed >> nib_shift) & 0x0F0F0F0Fu;
    const u8vec4 quants = unpack8(nibbles);

    return f16vec4((vec4(quants) - vec4(8.0)) * vec4(float(scale)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
|
||||
block_q4_1 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 {
|
||||
block_q4_1_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
|
|
@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q4_1: returns q * d + m for four quants starting at element
// coordInBlock[1]. The block is re-viewed through its packed32 layout so the four
// source bytes arrive in a single 32-bit load.
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;
    const float16_t offset = bl.block.m;

    decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl);

    // Nibble selector (0 or 4) and 32-bit word index (0..3).
    const uint nib_shift = (elem & 0x10) >> 2;
    const uint word32 = (elem & 0xC) >> 2;

    const uint packed = uint32_t(bl32.block.qs[word32]);
    const uint nibbles = (packed >> nib_shift) & 0x0F0F0F0Fu;
    const u8vec4 quants = unpack8(nibbles);

    return f16vec4(vec4(quants) * vec4(float(scale)) + vec4(float(offset)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
|
||||
block_q5_0 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 {
|
||||
block_q5_0_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
|
|
@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q5_0: combines the low nibble from qs with the fifth bit from qh
// and returns (q - 16) * d for four quants starting at element coordInBlock[1].
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl);

    // Low 4 bits: two adjacent packed16 words merged into one 32-bit value.
    const uint nib_shift = (elem & 0x10) >> 2; // 0 or 4
    const uint word16 = (elem & 0xC) >> 1;     // in {0,2,4,6}
    const uint lo_half = uint32_t(bl16.block.qs[word16]);
    const uint hi_half = uint32_t(bl16.block.qs[word16 + 1u]) << 16;
    const u8vec4 low_bits = unpack8(((lo_half | hi_half) >> nib_shift) & 0x0F0F0F0Fu);

    // High bit: after the shift, bits 0..3 belong to elements elem..elem+3.
    const uint qh_all = (uint(bl16.block.qh[1]) << 16) | uint(bl16.block.qh[0]);
    const uint qh_here = qh_all >> elem;
    const uvec4 qh_spread = uvec4(qh_here, qh_here >> 1u, qh_here >> 2u, qh_here >> 3u);
    const uvec4 high_bits = (qh_spread & uvec4(0x01u)) << 4u;

    return f16vec4((vec4(low_bits) + vec4(high_bits) - vec4(16.0)) * vec4(float(scale)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
|
||||
block_q5_1 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 {
|
||||
block_q5_1_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
|
|
@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q5_1: combines the low nibble from qs with the fifth bit from qh
// and returns q * d + m for four quants starting at element coordInBlock[1].
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;
    const float16_t offset = bl.block.m;

    decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl);

    // Low 4 bits come from one 32-bit word of qs.
    const uint nib_shift = (elem & 0x10) >> 2; // 0 or 4
    const uint word32 = (elem & 0xC) >> 2;     // 0..3
    const uint packed = uint32_t(bl32.block.qs[word32]);
    const u8vec4 low_bits = unpack8((packed >> nib_shift) & 0x0F0F0F0Fu);

    // High bit: after the shift, bits 0..3 belong to elements elem..elem+3.
    const uint qh_here = bl.block.qh >> elem;
    const uvec4 qh_spread = uvec4(qh_here, qh_here >> 1u, qh_here >> 2u, qh_here >> 3u);
    const uvec4 high_bits = (qh_spread & uvec4(0x01u)) << 4u;

    return f16vec4((vec4(low_bits) + vec4(high_bits)) * vec4(float(scale)) + vec4(float(offset)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
|
||||
block_q8_0_packed16 block;
|
||||
};
|
||||
|
|
@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q8_0: loads two adjacent 16-bit words of qs, reinterprets the
// resulting 32 bits as four signed bytes and scales them by d.
// NOTE(review): presumes coordInBlock[1] is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const float16_t scale = bl.block.d;

    // Two quants per packed16 word.
    const uint word16 = elem >> 1u;
    const uint lo_half = uint(uint16_t(bl.block.qs[word16]));
    const uint hi_half = uint(uint16_t(bl.block.qs[word16 + 1u])) << 16u;
    const uint packed = lo_half | hi_half;

    const i8vec4 quants = unpack8(int32_t(packed));
    return f16vec4(vec4(quants) * vec4(float(scale)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
|
||||
block_q2_K block;
|
||||
};
|
||||
|
|
@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
|
|||
block_q2_K_packed16 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 {
|
||||
block_q2_K_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
|
||||
|
|
@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q2_K: returns q * (dm.x * scale) - (dm.y * min) for four 2-bit
// quants starting at element coordInBlock[1]; all four share one scales byte.
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const f16vec2 dm = bl.block.dm;

    decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl);

    // One scales byte per 16 elements: low nibble scales, high nibble is the minimum.
    const uint sub16 = elem >> 4;                 // 0..15
    const uint packed_scale = bl.block.scales[sub16];
    const float16_t sub_scale = dm.x * float16_t(packed_scale & 0xF);
    const float16_t sub_min = dm.y * float16_t(packed_scale >> 4);

    // The packed16 index ((elem & 0x80) >> 3) + ((elem & 0x1E) >> 1) is even when
    // elem % 4 == 0, so halving it gives the packed32 word index used here.
    const uint word32 = ((elem & 0x80) >> 4) + ((elem & 0x1Cu) >> 2);
    const uint bit_shift = (elem & 0x60) >> 4;    // 0,2,4,6

    const uint packed = uint32_t(bl32.block.qs[word32]);
    const uint two_bit = (packed >> bit_shift) & 0x03030303u;
    const u8vec4 quants = unpack8(two_bit);

    return f16vec4(vec4(quants) * vec4(float(sub_scale)) - vec4(float(sub_min)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
|
||||
block_q3_K block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 {
|
||||
block_q3_K_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
|
|
@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q3_K: rebuilds four 3-bit quants (2 low bits from qs, 1 high bit
// from hmask), subtracts 4 and multiplies by d * (6-bit scale - 32).
// NOTE(review): presumes coordInBlock[1] is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl);

    const uint half128 = elem >> 7;                // 0,1
    const uint sub16 = elem >> 4;                  // 0..15
    const uint quarter = (elem & 0x60) >> 5;       // 0,1,2,3
    const uint lo_shift = quarter << 1;            // 0,2,4,6
    const uint hi_bit = (half128 << 2) + quarter;  // 0..7 (bit position in hmask byte)

    // The 6-bit scale is split: 4 low bits in one scales byte, 2 high bits in another.
    uint32_t sc_lo_idx;
    uint32_t sc_lo_shift;
    if (sub16 < 8) {
        sc_lo_idx = sub16;
        sc_lo_shift = 0u;
    } else {
        sc_lo_idx = sub16 - 8;
        sc_lo_shift = 4u;
    }
    const uint32_t sc_hi_idx = sub16 + 8 - (sub16 / 4) * 4;
    const uint32_t sc_hi_shift = (sub16 / 4) * 2;

    const int8_t us = int8_t(
        ((bl.block.scales[sc_lo_idx] >> sc_lo_shift) & 0xF) |
        (((bl.block.scales[sc_hi_idx] >> sc_hi_shift) & 3) << 4));
    const float16_t dl = bl.block.d * float16_t(int(us) - 32);

    // For elem % 4 == 0, (elem & 0x1F) equals (elem & 0x1C), a multiple of 4.
    const uint hm_byte = (elem & 0x1Cu);
    const uint qs_byte = (half128 << 5) + hm_byte;

    // Pairs of adjacent packed16 reads are merged so that byte j of each 32-bit value
    // holds the data for element elem + j.
    const uint qs_lo = uint32_t(bl16.block.qs[qs_byte >> 1]);
    const uint qs_hi = uint32_t(bl16.block.qs[(qs_byte >> 1) + 1u]) << 16;
    const uint hm_lo = uint32_t(bl16.block.hmask[hm_byte >> 1]);
    const uint hm_hi = uint32_t(bl16.block.hmask[(hm_byte >> 1) + 1u]) << 16;
    const uint qs_packed = qs_lo | qs_hi;
    const uint hm_packed = hm_lo | hm_hi;

    // Shifts stay below 8 and the masks are per byte, so nothing leaks between bytes.
    const uint low2 = (qs_packed >> lo_shift) & 0x03030303u;
    const uint high1 = (hm_packed >> hi_bit) & 0x01010101u;

    const ivec4 quants = ivec4(unpack8(low2 | (high1 << 2))) - ivec4(4);
    return f16vec4(vec4(quants) * vec4(float(dl)));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
|
||||
block_q4_K block;
|
||||
};
|
||||
|
|
@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
|||
block_q4_K_packed16 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 {
|
||||
block_q4_K_packed32 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
|
||||
block_q4_K_packed128 block;
|
||||
};
|
||||
|
|
@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
|||
return float16_t(ret);
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q4_K: returns d * q - m for four 4-bit quants starting at element
// coordInBlock[1]. The sub-block scale/min pair is read from shAscales when
// IS_MUL_MM2 and DATA_A_Q4_K are both defined, and otherwise unpacked from the first
// 128-bit word of the block.
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const uint sub32 = elem >> 5;    // 0..7

    decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl);
    decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);

#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
    vec2 cached = shAscales[sub32 * shAscales_stride + (blockCoords[0] % BM)];
    float eff_scale = cached.x;
    float eff_min = cached.y;
#else
    const uvec4 head = bl128.block.q4k[0];
    const vec2 super_dm = vec2(unpackFloat2x16(head.x));

    const uint32_t raw0 = head.y;
    const uint32_t raw4 = head.z;
    const uint32_t raw8 = head.w;

    // Sub-blocks 0..3 read their 6-bit values directly; 4..7 are reassembled from
    // nibbles of raw8 plus the top two bits of raw0 / raw4.
    uint32_t sc_word;
    uint32_t min_word;
    if (sub32 < 4) {
        sc_word = raw0;
        min_word = raw4;
    } else {
        sc_word = (raw8 & 0x0F0F0F0F) | ((raw0 & 0xC0C0C0C0) >> 2);
        min_word = ((raw8 & 0xF0F0F0F0) >> 4) | ((raw4 & 0xC0C0C0C0) >> 2);
    }

    const uint byte_shift = 8 * (sub32 & 3);
    const uint32_t sc6 = (sc_word >> byte_shift) & 0x3F;
    const uint32_t min6 = (min_word >> byte_shift) & 0x3F;

    const float eff_scale = super_dm.x * float(sc6);
    const float eff_min = super_dm.y * float(min6);
#endif

    // packed32 word index: (elem >> 6) * 8 + ((elem & 0x1E) >> 2). The nibble shift is
    // 0 or 4 only, so one (w >> shift) & 0x0F0F0F0F isolates all four nibbles without
    // inter-byte leakage.
    const uint nib_shift = (elem & 0x20u) >> 3u;
    const uint word32 = (elem >> 6) * 8u + ((elem & 0x1Eu) >> 2);
    const uint packed = uint32_t(bl32.block.qs[word32]);
    const u8vec4 quants = unpack8((packed >> nib_shift) & 0x0F0F0F0Fu);

    return f16vec4(vec4(eff_scale) * vec4(quants) - vec4(eff_min));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
|
||||
block_q5_K block;
|
||||
};
|
||||
|
|
@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
|
|||
block_q5_K_packed128 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 {
|
||||
block_q5_K_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
|
||||
|
|
@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
|||
return float16_t(ret);
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q5_K: returns q * d - m for four 5-bit quants (low nibble from qs,
// fifth bit from qh) starting at element coordInBlock[1]. The sub-block scale/min pair
// is read from shAscales when IS_MUL_MM2 and DATA_A_Q5_K are both defined, and
// otherwise unpacked from the first 128-bit word of the block.
// NOTE(review): presumes the start index is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    const uint sub32 = elem >> 5;

    decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl);
    decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);

#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
    vec2 cached = shAscales[sub32 * shAscales_stride + (blockCoords[0] % BM)];
    float eff_scale = cached.x;
    float eff_min = cached.y;
#else
    const uvec4 head = bl128.block.q5k[0];
    const f16vec2 super_dm = unpackFloat2x16(head.x);

    const uint32_t raw0 = head.y;
    const uint32_t raw4 = head.z;
    const uint32_t raw8 = head.w;

    // Sub-blocks 0..3 read their 6-bit values directly; 4..7 are reassembled from
    // nibbles of raw8 plus the top two bits of raw0 / raw4.
    uint32_t sc_word;
    uint32_t min_word;
    if (sub32 < 4) {
        sc_word = raw0;
        min_word = raw4;
    } else {
        sc_word = (raw8 & 0x0F0F0F0F) | ((raw0 & 0xC0C0C0C0) >> 2);
        min_word = ((raw8 & 0xF0F0F0F0) >> 4) | ((raw4 & 0xC0C0C0C0) >> 2);
    }

    const uint byte_shift = 8 * (sub32 & 3);
    const uint32_t sc6 = (sc_word >> byte_shift) & 0x3F;
    const uint32_t min6 = (min_word >> byte_shift) & 0x3F;

    const float16_t eff_scale = super_dm.x * float16_t(sc6);
    const float16_t eff_min = super_dm.y * float16_t(min6);
#endif

    // The nibble shift is 0 or 4; mask 0x0F0F0F0F covers the four nibbles either way
    // (no inter-byte leakage).
    const uint nib_shift = (elem & 0x20u) >> 3u;
    const uint qs_word = (elem >> 6) * 8u + ((elem & 0x1Eu) >> 2);
    const uint qh_word = (elem & 0x1Eu) >> 2;

    const uint low4 = (uint32_t(bl32.block.qs[qs_word]) >> nib_shift) & 0x0F0F0F0Fu;
    // qh keeps bit `sub32` per element across 4 consecutive bytes; one shift+mask handles all 4.
    const uint high1 = ((uint32_t(bl32.block.qh[qh_word]) >> sub32) & 0x01010101u) << 4u;

    const u8vec4 quants = unpack8(low4 | high1);
    return f16vec4(vec4(quants) * vec4(eff_scale) - vec4(eff_min));
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
||||
block_q6_K block;
|
||||
};
|
||||
|
|
@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
|
|||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide decoder for Q6_K: rebuilds four 6-bit quants (4 low bits from ql, 2 high bits
// from qh), subtracts 32 and multiplies by d * scales[elem / 16].
// NOTE(review): presumes coordInBlock[1] is a multiple of 4 - confirm against the caller.
f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint elem = coordInBlock[1];
    decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);

    // One scale per 16 elements.
    const uint sub16 = elem >> 4;
    const float16_t eff_scale = bl.block.d * float16_t(bl.block.scales[sub16]);

    const uint nib_sel = (elem & 0x40) >> 6;
    const uint lo_shift = nib_sel * 4;           // 0 or 4
    const uint hi_shift = (elem & 0x60) >> 4;    // 0,2,4,6

    const uint ql_word = ((elem & 0x80) >> 2) + ((elem & 0x3E) >> 1);
    const uint qh_word = ((elem & 0x80) >> 3) + ((elem & 0x1E) >> 1);

    // Pairs of adjacent packed16 reads are merged so that byte j of each 32-bit value
    // holds the data for element elem + j.
    const uint ql_lo = uint32_t(bl16.block.ql[ql_word]);
    const uint ql_hi = uint32_t(bl16.block.ql[ql_word + 1]) << 16;
    const uint qh_lo = uint32_t(bl16.block.qh[qh_word]);
    const uint qh_hi = uint32_t(bl16.block.qh[qh_word + 1]) << 16;
    const uint ql_packed = ql_lo | ql_hi;
    const uint qh_packed = qh_lo | qh_hi;

    // Per-byte masks 0x0F / 0x03 keep only the wanted bits; the two qh bits are then
    // moved above the nibble.
    const uint low4 = (ql_packed >> lo_shift) & 0x0F0F0F0Fu;
    const uint high2 = ((qh_packed >> hi_shift) & 0x03030303u) << 4u;

    const ivec4 quants = ivec4(unpack8(low4 | high2));
    return f16vec4((vec4(quants) - vec4(32.0f)) * vec4(float(eff_scale)));
}
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
|
||||
block_iq1_s block;
|
||||
|
|
@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
|
|||
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide IQ1_S decode for the elements starting at idx = coordInBlock[1].
// The grid entry covers 8 elements with 2 bits each; bit 2 of idx picks
// which half of the entry is returned. blockCoords is not used.
f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    const uint group32 = idx >> 5;
    const uint group8  = idx >> 3;

    const uint qh = bl.block.qh[group32];
    const uint qs = bl.block.qs[group8];

    // Scale: block d times an odd multiplier built from qh bits 12..14.
    const float16_t d = bl.block.d;
    const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1);
    // Bit 15 of qh flips the sign of the delta offset.
    const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;

    // Grid index = qs byte plus a 3-bit field of qh placed at bits 8..10.
    const uint gridHigh = bitfieldExtract(qh, 3 * int(group8 & 3), 3);
    const uint grid = iq1s_grid[qs | (gridHigh << 8)];

    // Signed 2-bit fields; firstBit is 0 or 8 depending on bit 2 of idx.
    const int gridBits = int(grid);
    const int firstBit = 2 * int(idx & 4);
    const ivec4 q = ivec4(
        bitfieldExtract(gridBits, firstBit,     2),
        bitfieldExtract(gridBits, firstBit + 2, 2),
        bitfieldExtract(gridBits, firstBit + 4, 2),
        bitfieldExtract(gridBits, firstBit + 6, 2));

    return f16vec4((vec4(q) + vec4(delta)) * dl);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
|
|
@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords
|
|||
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide IQ1_M decode for the elements starting at idx = coordInBlock[1].
// The fp16 block scale is reassembled from the top nibble of each 16-bit
// half of the packed 64-bit scales field. blockCoords is not used.
f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufIQ1_M_packed64 view64 = decodeBufIQ1_M_packed64(bl);

    // Gather four 4-bit pieces into the 16 bits of the half-precision scale.
    uvec2 scales = unpack32(view64.block.scales);
    const uint dBits = ((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16);
    const float16_t d = uint16BitsToHalf(uint16_t(dBits));

    const uint group8  = idx >> 3;
    const uint group16 = idx >> 4;

    const uint sc = bl.block.scales[group8 >> 3];
    const uint qs = bl.block.qs[group8];
    // Each qh byte serves two 8-element groups, one nibble each.
    const uint qh = bl.block.qh[group16] >> (4 * (group8 & 1));

    const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(group16 & 3), 3)) + 1.0;
    // Bit 3 of the qh nibble flips the sign of the delta offset.
    const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;
    // Low 3 bits of the qh nibble extend the grid index to 11 bits.
    const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)];

    // Signed 2-bit fields; firstBit is 0 or 8 depending on bit 2 of idx.
    const int gridBits = int(grid);
    const int firstBit = 2 * int(idx & 4);
    const ivec4 q = ivec4(
        bitfieldExtract(gridBits, firstBit,     2),
        bitfieldExtract(gridBits, firstBit + 2, 2),
        bitfieldExtract(gridBits, firstBit + 4, 2),
        bitfieldExtract(gridBits, firstBit + 6, 2));

    return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl));
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
|
|
@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
|
|||
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
return float16_t(ret[idx & 1]);
|
||||
}
|
||||
|
||||
// 4-wide IQ2_XXS decode for the elements starting at idx = coordInBlock[1].
// Magnitudes come from iq2xxs_grid, signs from a 7-bit field extended with
// its bit count, and the scale from the top 4 bits of the same 32-bit word.
// blockCoords is not used.
f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufIQ2_XXS_packed16 view16 = decodeBufIQ2_XXS_packed16(bl);

    const uint group32 = idx >> 5;
    const uint sub8    = (idx & 0x18) >> 3;

    // Grid index byte for this 8-element group.
    const uint qs = bl.block.qs[8 * group32 + sub8];

    // Rebuild the 32-bit sign/scale word from two 16-bit halves.
    const uint wordBase = 4 * group32 + 2;
    const uint signscale = pack32(u16vec2(view16.block.qs[wordBase], view16.block.qs[wordBase + 1]));
    const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));

    // 7 stored sign bits; the bit count is OR-ed in starting at bit 7.
    const uint sign7 = bitfieldExtract(signscale, 7 * int(sub8), 7);
    const uint sign8 = sign7 | (bitCount(sign7) << 7);
    // Bring the sign bit of element idx down to bit 0.
    const uint sb = sign8 >> (idx & 7u);

    // Each grid entry holds two 32-bit halves of four byte magnitudes each.
    const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
    const u8vec4 g = unpack8(g2);

    const vec4 signMul = vec4(
        (sb & 1u) != 0u ? -1.0 : 1.0,
        (sb & 2u) != 0u ? -1.0 : 1.0,
        (sb & 4u) != 0u ? -1.0 : 1.0,
        (sb & 8u) != 0u ? -1.0 : 1.0);

    return f16vec4(dscale * vec4(g) * signMul);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XS)
|
||||
|
|
@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
|
|||
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
return float16_t(ret[idx & 1]);
|
||||
}
|
||||
|
||||
// 4-wide IQ2_XS decode for the elements starting at idx = coordInBlock[1].
// Each 16-bit qs entry carries a 9-bit grid index (low bits) and 7 sign
// bits (high bits). blockCoords is not used.
f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];

    // One scales byte per 32 elements; bit 4 of idx picks its low/high nibble.
    const uint scaleIdx   = idx >> 5;
    const uint scaleShift = (idx & 0x10) >> 2;
    const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[scaleIdx] >> scaleShift) & 0xF));

    const uint16_t qs = bl.block.qs[idx >> 3];

    // 7 stored sign bits; the bit count is OR-ed in starting at bit 7.
    const uint sign7 = uint(qs >> 9);
    const uint sign8 = sign7 | (bitCount(sign7) << 7);
    // Bring the sign bit of element idx down to bit 0.
    const uint sb = sign8 >> (idx & 7u);

    // Each grid entry holds two 32-bit halves of four byte magnitudes each.
    const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
    const u8vec4 g = unpack8(g2);

    const vec4 signMul = vec4(
        (sb & 1u) != 0u ? -1.0 : 1.0,
        (sb & 2u) != 0u ? -1.0 : 1.0,
        (sb & 4u) != 0u ? -1.0 : 1.0,
        (sb & 8u) != 0u ? -1.0 : 1.0);

    return f16vec4(dscale * vec4(g) * signMul);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_S)
|
||||
|
|
@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords
|
|||
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
// 4-wide IQ2_S decode for the elements starting at idx = coordInBlock[1].
// The grid index is a qs byte plus two bits from qh; sign bytes are read
// from the qs array at offset QUANT_K / 8. blockCoords is not used.
f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    const uint group32 = idx >> 5;
    const uint group8  = idx >> 3;

    // One scales byte per 32 elements; bit 4 of idx picks its low/high nibble.
    const uint scale = (bl.block.scales[group32] >> ((idx & 0x10) >> 2)) & 0xf;
    const float db = float(bl.block.d) * 0.25 * (0.5 + scale);

    const uint qs = bl.block.qs[group8];
    const uint qh = bl.block.qh[group32];
    // Sign byte for this 8-element group, shifted so element idx sits at bit 0.
    const uint sb = uint(bl.block.qs[QUANT_K / 8 + group8]) >> (idx & 0x6u);

    // Two qh bits per 8-element group are moved to bits 8..9 of the grid index.
    const uint qhShift = 2 * (group8 & 3u);
    const uint gridIdx = qs | ((qh << (8 - qhShift)) & 0x300);

    // Each grid entry holds two 32-bit halves of four byte magnitudes each.
    const uint g2 = iq2s_grid[gridIdx][(idx & 4) >> 2];
    const u8vec4 g = unpack8(g2);

    const vec4 signMul = vec4(
        (sb & 1u) != 0u ? -1.0 : 1.0,
        (sb & 2u) != 0u ? -1.0 : 1.0,
        (sb & 4u) != 0u ? -1.0 : 1.0,
        (sb & 8u) != 0u ? -1.0 : 1.0);

    return f16vec4(db * vec4(g) * signMul);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_XXS)
|
||||
|
|
@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo
|
|||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
// 4-wide IQ3_XXS decode for the elements starting at idx = coordInBlock[1].
// One iq3xxs_grid entry supplies the four byte magnitudes; signs and the
// 4-bit scale come from a 32-bit word stored after the first QUANT_K / 4
// bytes of qs. blockCoords is not used.
f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufIQ3_XXS_packed16 view16 = decodeBufIQ3_XXS_packed16(bl);

    // One grid index byte per 4 elements.
    const uint iqs = idx >> 2;
    const uint qs = bl.block.qs[iqs];

    // Byte offset of the sign/scale word for this 32-element group.
    const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);
    const uint wordIdx = is / 2;
    const uint signWord = pack32(u16vec2(view16.block.qs[wordIdx + 0], view16.block.qs[wordIdx + 1]));

    const float d = float(bl.block.d);
    const float db = d * 0.5 * (0.5 + (signWord >> 28));

    // 7 stored sign bits; the bit count is OR-ed in starting at bit 7, then
    // the result is shifted so the sign of element idx sits at bit 0.
    const uint sign7 = bitfieldExtract(signWord, 7 * (int(iqs / 2) % 4), 7);
    const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u);

    const uint grid = iq3xxs_grid[qs];
    const u8vec4 g = unpack8(grid);

    const vec4 signMul = vec4(
        (sb & 1u) != 0u ? -1.0 : 1.0,
        (sb & 2u) != 0u ? -1.0 : 1.0,
        (sb & 4u) != 0u ? -1.0 : 1.0,
        (sb & 8u) != 0u ? -1.0 : 1.0);

    return f16vec4(db * vec4(g) * signMul);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_S)
|
||||
|
|
@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
|
|||
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
// 4-wide IQ3_S decode for the elements starting at idx = coordInBlock[1].
// The grid index is a qs byte plus one bit from qh; signs come from the
// dedicated signs array. blockCoords is not used.
f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];

    // One qs byte per 4 elements, one qh byte per 32 elements.
    const uint iqs = idx >> 2;
    const uint iqh = idx >> 5;
    const uint qs = bl.block.qs[iqs];
    const uint qh = bl.block.qh[iqh];

    // Sign byte covering 8 elements, shifted so element idx sits at bit 0.
    const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u);

    // Each scales byte holds two 4-bit scales; the low bit of iqh picks the nibble.
    const uint scale = bl.block.scales[iqs / 16];
    const uint scale4 = (scale >> (4 * (iqh & 1))) & 0xf;
    const float d = float(bl.block.d);
    const float db = d * (1 + 2 * scale4);

    // Bit (iqs % 8) of qh becomes bit 8 of the grid index.
    const uint gridIdx = qs | ((qh << (8 - (iqs % 8))) & 256);
    const uint grid = iq3s_grid[gridIdx];
    const u8vec4 g = unpack8(grid);

    const vec4 signMul = vec4(
        (sb & 1u) != 0u ? -1.0 : 1.0,
        (sb & 2u) != 0u ? -1.0 : 1.0,
        (sb & 4u) != 0u ? -1.0 : 1.0,
        (sb & 8u) != 0u ? -1.0 : 1.0);

    return f16vec4(db * vec4(g) * signMul);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_XS)
|
||||
|
|
@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
|
|||
block_iq4_xs block;
|
||||
};
|
||||
|
||||
// Buffer-reference view exposing one block_iq4_xs_packed32 as `block` (declared with 4-byte reference alignment).
// dequantFuncIQ4_XS_v converts a decodeBufIQ4_XS reference to this type to read 32-bit words.
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 {
|
||||
block_iq4_xs_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
|
|
@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor
|
|||
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide IQ4_XS decode for the elements starting at idx = coordInBlock[1].
// Four 4-bit codes are pulled from one 32-bit qs word, mapped through
// kvalues_iq4nl and multiplied by the sub-block scale. blockCoords is not used.
f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufIQ4_XS_packed32 view32 = decodeBufIQ4_XS_packed32(bl);

    // 32-element sub-block number, 0..7.
    const uint group32 = idx >> 5;

    // 6-bit sub-block scale: 4 low bits from scales_l, 2 high bits from scales_h, biased by 32.
    const uint scaleLow  = (view32.block.scales_l >> (4 * group32)) & 0xF;
    const uint scaleHigh = (uint(view32.block.scales_h) >> (2 * group32)) & 0x3;
    const float16_t d = bl.block.d;
    const float16_t dl = d * float16_t(int(scaleLow | (scaleHigh << 4)) - 32);

    // Bit 4 of idx selects the low (shift 0) or high (shift 4) nibble of every byte.
    const uint nibbleShift = (idx & 0x10) >> 2;
    // 32-bit word index into qs, in [0,32).
    const uint wordIdx = 4 * group32 + ((idx & 0xC) >> 2);

    const uint word = view32.block.qs[wordIdx];
    const u8vec4 codes = unpack8((word >> nibbleShift) & 0x0F0F0F0Fu);

    const vec4 values = vec4(
        float(kvalues_iq4nl[codes.x]),
        float(kvalues_iq4nl[codes.y]),
        float(kvalues_iq4nl[codes.z]),
        float(kvalues_iq4nl[codes.w]));

    return f16vec4(values * float(dl));
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
|
|
@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
|
|||
block_iq4_nl block;
|
||||
};
|
||||
|
||||
// Buffer-reference view exposing one block_iq4_nl_packed16 as `block` (declared with 2-byte reference alignment).
// dequantFuncIQ4_NL_v converts a decodeBufIQ4_NL reference to this type to read 16-bit words.
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 {
|
||||
block_iq4_nl_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
|
|
@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
|||
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide IQ4_NL decode for the elements starting at idx = coordInBlock[1].
// Four 4-bit codes are mapped through kvalues_iq4nl and scaled by the
// block's d. blockCoords is not used.
f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufIQ4_NL_packed16 view16 = decodeBufIQ4_NL_packed16(bl);

    // Bit 4 of idx selects the low (shift 0) or high (shift 4) nibble of every byte.
    const uint nibbleShift = (idx & 0x10) >> 2;
    // First 16-bit word to read: one of {0,2,4,6}.
    const uint wordIdx = (idx & 0xC) >> 1;

    // Fuse two adjacent 16-bit words into one 32-bit value (four qs bytes).
    const uint wordLo = uint32_t(view16.block.qs[wordIdx     ]);
    const uint wordHi = uint32_t(view16.block.qs[wordIdx + 1u]);
    const uint word = wordLo | (wordHi << 16);

    // The per-byte 0x0F mask keeps exactly the selected nibble of each byte.
    const u8vec4 codes = unpack8((word >> nibbleShift) & 0x0F0F0F0Fu);

    const float scale = float(bl.block.d);
    return f16vec4(
        scale * float(kvalues_iq4nl[codes.x]),
        scale * float(kvalues_iq4nl[codes.y]),
        scale * float(kvalues_iq4nl[codes.z]),
        scale * float(kvalues_iq4nl[codes.w]));
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
|
|
@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
|||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// 4-wide MXFP4 decode for the elements starting at idx = coordInBlock[1].
// Four consecutive qs bytes are read, one nibble is taken from each, and
// the kvalues_mxfp4 entries are scaled by e8m0_to_fp32(e) * 0.5.
// blockCoords is not used.
f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    const float d = e8m0_to_fp32(bl.block.e);

    // Low 4 bits of idx give the first byte; bit 4 selects low/high nibble.
    const uint firstByte   = idx & 0xF;
    const uint nibbleShift = (idx & 0x10) >> 2;

    const uvec4 bytes = uvec4(
        uint(bl.block.qs[firstByte]),
        uint(bl.block.qs[firstByte + 1u]),
        uint(bl.block.qs[firstByte + 2u]),
        uint(bl.block.qs[firstByte + 3u]));
    const uvec4 codes = (bytes >> nibbleShift) & 0xFu;

    const vec4 values = vec4(
        float(kvalues_mxfp4[codes.x]),
        float(kvalues_mxfp4[codes.y]),
        float(kvalues_mxfp4[codes.z]),
        float(kvalues_mxfp4[codes.w]));

    return f16vec4(values * d * 0.5f);
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
|
|
@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF
|
|||
block_nvfp4 block;
|
||||
};
|
||||
|
||||
// Buffer-reference view exposing one block_nvfp4_packed32 as `block` (declared with 4-byte reference alignment).
// dequantFuncNVFP4_v converts a decodeBufNVFP4 reference to this type to read 32-bit qs words.
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 {
|
||||
block_nvfp4_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
|
|
@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords
|
|||
qs = (qs >> shift) & 0xF;
|
||||
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
}
|
||||
|
||||
// 4-wide NVFP4 decode for the elements starting at idx = coordInBlock[1].
// Four 4-bit codes are pulled from one 32-bit qs word, mapped through
// kvalues_mxfp4 and scaled by ue4m3_to_fp32(d[idx >> 4]) * 0.5.
// blockCoords is not used.
f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint idx = coordInBlock[1];
    decodeBufNVFP4_packed32 view32 = decodeBufNVFP4_packed32(bl);

    // One scale byte per 16 elements.
    const uint sub = idx >> 4;
    const float d = ue4m3_to_fp32(bl.block.d[sub]);

    // 32-bit word index into qs, in [0,8).
    const uint wordIdx = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2);
    // Bit 3 of idx selects the low (shift 0) or high (shift 4) nibble of every byte.
    const uint nibbleShift = (idx & 0x8) >> 1;

    const uint word = uint32_t(view32.block.qs[wordIdx]);
    const u8vec4 codes = unpack8((word >> nibbleShift) & 0x0F0F0F0Fu);

    const vec4 values = vec4(
        float(kvalues_mxfp4[codes.x]),
        float(kvalues_mxfp4[codes.y]),
        float(kvalues_mxfp4[codes.z]),
        float(kvalues_mxfp4[codes.w]));

    return f16vec4(values * d * 0.5f);
}
|
||||
#endif
|
||||
|
||||
// Select the decoders for the A matrix from the DATA_A_* macro in effect:
//   dequantFuncA    - scalar decoder
//   dequantFuncA_v  - 4-wide decoder (every format below except F32)
// Q4_K and Q5_K additionally select their fetch_scales / store_scales helpers.
#if defined(DATA_A_Q1_0)
  #define dequantFuncA dequantFuncQ1_0
  #define dequantFuncA_v dequantFuncQ1_0_v
#elif defined(DATA_A_Q4_0)
  #define dequantFuncA dequantFuncQ4_0
  #define dequantFuncA_v dequantFuncQ4_0_v
#elif defined(DATA_A_Q4_1)
  #define dequantFuncA dequantFuncQ4_1
  #define dequantFuncA_v dequantFuncQ4_1_v
#elif defined(DATA_A_Q5_0)
  #define dequantFuncA dequantFuncQ5_0
  #define dequantFuncA_v dequantFuncQ5_0_v
#elif defined(DATA_A_Q5_1)
  #define dequantFuncA dequantFuncQ5_1
  #define dequantFuncA_v dequantFuncQ5_1_v
#elif defined(DATA_A_Q8_0)
  #define dequantFuncA dequantFuncQ8_0
  #define dequantFuncA_v dequantFuncQ8_0_v
#elif defined(DATA_A_Q2_K)
  #define dequantFuncA dequantFuncQ2_K
  #define dequantFuncA_v dequantFuncQ2_K_v
#elif defined(DATA_A_Q3_K)
  #define dequantFuncA dequantFuncQ3_K
  #define dequantFuncA_v dequantFuncQ3_K_v
#elif defined(DATA_A_Q4_K)
  #define dequantFuncA dequantFuncQ4_K
  #define dequantFuncA_v dequantFuncQ4_K_v
  #define fetch_scales fetch_scalesQ4_K
  #define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
  #define dequantFuncA dequantFuncQ5_K
  #define dequantFuncA_v dequantFuncQ5_K_v
  #define fetch_scales fetch_scalesQ5_K
  // NOTE(review): Q5_K maps store_scales to the Q4_K helper - presumably shared on purpose; confirm.
  #define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
  #define dequantFuncA dequantFuncQ6_K
  #define dequantFuncA_v dequantFuncQ6_K_v
#elif defined(DATA_A_IQ1_S)
  #define dequantFuncA dequantFuncIQ1_S
  #define dequantFuncA_v dequantFuncIQ1_S_v
#elif defined(DATA_A_IQ1_M)
  #define dequantFuncA dequantFuncIQ1_M
  #define dequantFuncA_v dequantFuncIQ1_M_v
#elif defined(DATA_A_IQ2_XXS)
  #define dequantFuncA dequantFuncIQ2_XXS
  #define dequantFuncA_v dequantFuncIQ2_XXS_v
#elif defined(DATA_A_IQ2_XS)
  #define dequantFuncA dequantFuncIQ2_XS
  #define dequantFuncA_v dequantFuncIQ2_XS_v
#elif defined(DATA_A_IQ2_S)
  #define dequantFuncA dequantFuncIQ2_S
  #define dequantFuncA_v dequantFuncIQ2_S_v
#elif defined(DATA_A_IQ3_XXS)
  #define dequantFuncA dequantFuncIQ3_XXS
  #define dequantFuncA_v dequantFuncIQ3_XXS_v
#elif defined(DATA_A_IQ3_S)
  #define dequantFuncA dequantFuncIQ3_S
  #define dequantFuncA_v dequantFuncIQ3_S_v
#elif defined(DATA_A_IQ4_XS)
  #define dequantFuncA dequantFuncIQ4_XS
  #define dequantFuncA_v dequantFuncIQ4_XS_v
#elif defined(DATA_A_IQ4_NL)
  #define dequantFuncA dequantFuncIQ4_NL
  #define dequantFuncA_v dequantFuncIQ4_NL_v
#elif defined(DATA_A_MXFP4)
  #define dequantFuncA dequantFuncMXFP4
  #define dequantFuncA_v dequantFuncMXFP4_v
#elif defined(DATA_A_NVFP4)
  #define dequantFuncA dequantFuncNVFP4
  #define dequantFuncA_v dequantFuncNVFP4_v
#elif defined(DATA_A_F32)
  #define dequantFuncA dequantFuncF32
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
#version 460

// Compiles only when the shader compiler accepts the extension below;
// the entry point is intentionally empty.
#extension GL_NV_cooperative_matrix_decode_vector : require

void main() {}
|
||||
|
|
@ -11,6 +11,9 @@
|
|||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#ifdef GL_NV_cooperative_matrix_decode_vector
|
||||
#extension GL_NV_cooperative_matrix_decode_vector : enable
|
||||
#endif
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
|
@ -54,6 +57,41 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const
|
|||
}
|
||||
}
|
||||
|
||||
// V=4 vector decode for K/V; dispatches to per-format _v decoders.
|
||||
// 4-wide decode callback for K: forwards to the per-format _v decoder chosen
// by FaTypeK, after converting bl_in to that format's buffer reference.
// Values of FaTypeK without a case yield a zero vector.
// NOTE(review): the case labels look like ggml type ids - confirm against the host code.
f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    f16vec4 decoded = f16vec4(0);
    switch (FaTypeK) {
    case 0u:
        decoded = f16vec4(decodeBufF32(bl_in).block);
        break;
    case 2u:
        decoded = dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
        break;
    case 3u:
        decoded = dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
        break;
    case 6u:
        decoded = dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
        break;
    case 7u:
        decoded = dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
        break;
    case 8u:
        decoded = dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
        break;
    case 41u:
        decoded = dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
        break;
    default:
        break;
    }
    return decoded;
}
|
||||
|
||||
// 4-wide decode callback for V: forwards to the per-format _v decoder chosen
// by FaTypeV, after converting bl_in to that format's buffer reference.
// Values of FaTypeV without a case yield a zero vector.
// NOTE(review): the case labels look like ggml type ids - confirm against the host code.
f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
    f16vec4 decoded = f16vec4(0);
    switch (FaTypeV) {
    case 0u:
        decoded = f16vec4(decodeBufF32(bl_in).block);
        break;
    case 2u:
        decoded = dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
        break;
    case 3u:
        decoded = dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
        break;
    case 6u:
        decoded = dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
        break;
    case 7u:
        decoded = dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
        break;
    case 8u:
        decoded = dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
        break;
    case 41u:
        decoded = dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
        break;
    default:
        break;
    }
    return decoded;
}
|
||||
|
||||
// Trailing decode arguments appended to the coopMatLoadTensorNV calls.
// When the compiler defines GL_NV_cooperative_matrix_decode_vector, the
// 4-wide decoder is passed after the scalar one; otherwise only the scalar one.
#ifdef GL_NV_cooperative_matrix_decode_vector
  #define FADECODEK , faDecodeK, faDecodeKVector
  #define FADECODEV , faDecodeV, faDecodeVVector
#else
  #define FADECODEK , faDecodeK
  #define FADECODEV , faDecodeV
#endif
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
||||
|
|
@ -259,7 +297,7 @@ void main() {
|
|||
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
|
||||
const bool k_use_decode = (bs_k > 1u);
|
||||
if (k_use_decode) {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK);
|
||||
} else {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
}
|
||||
|
|
@ -325,7 +363,7 @@ void main() {
|
|||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
const bool v_use_decode = (bs_v > 1u);
|
||||
if (v_use_decode) {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV);
|
||||
} else {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,10 +71,12 @@ layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
|||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
// DECODEFUNCA expands to the trailing decode arguments for loading the A matrix:
//   - empty when A is not quantized (QUANT_K <= 1),
//   - the scalar decoder alone when no 4-wide decoder or extension is available,
//   - scalar and 4-wide decoders otherwise.
// The macro is defined exactly once per branch; an earlier unconditional
// definition placed before the include would clash with the ones below.
#if QUANT_K > 1

#include "dequant_funcs_cm2.glsl"

// dequantFuncA_v is only defined by the include for formats that have a 4-wide decoder.
#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector)
#define DECODEFUNCA , dequantFuncA, dequantFuncA_v
#else
#define DECODEFUNCA , dequantFuncA
#endif
#else
#define DECODEFUNCA
#endif
|
||||
|
|
|
|||
|
|
@ -1722,11 +1722,18 @@ struct block_nvfp4
|
|||
uint8_t qs[QUANT_K_NVFP4 / 2];
|
||||
};
|
||||
|
||||
// NVFP4 block declared with 32-bit members: d holds QUANT_K_NVFP4 / 16 / 4 words and
// qs holds QUANT_K_NVFP4 / 2 / 4 words (the same byte count as block_nvfp4's uint8_t qs array).
// NOTE(review): presumably overlays block_nvfp4 byte-for-byte - confirm against that struct's d member.
struct block_nvfp4_packed32
|
||||
{
|
||||
uint32_t d[QUANT_K_NVFP4 / 16 / 4];
|
||||
uint32_t qs[QUANT_K_NVFP4 / 2 / 4];
|
||||
};
|
||||
|
||||
// Type configuration used when the A matrix is stored as NVFP4.
#if defined(DATA_A_NVFP4)
  #define QUANT_K QUANT_K_NVFP4
  #define QUANT_R QUANT_R_NVFP4
  #define QUANT_AUXF 1
  #define A_TYPE block_nvfp4
  #define A_TYPE_PACKED32 block_nvfp4_packed32
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue