ggml-webgpu: Fix dequantization helpers to not pass in pointers (#21872)

* Fix dequantization helpers to not pass in pointers

* Increase XIELU precision
This commit is contained in:
Reese Levine 2026-04-15 09:14:40 -07:00 committed by GitHub
parent a6206958d2
commit 20d3bc2cc8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 223 additions and 192 deletions

View file

@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_u16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
return (word >> shift) & 0xFFFF;
#ifdef DECLARE_BYTE_LOADERS_SRC
// Returns the 16-bit value located at `byte_offset` in `src`, placed in the low
// 16 bits of the returned u32 (upper bits are masked to zero).
// Only bit 1 of the offset is used to choose between the low and high half of
// the containing word; bit 0 is ignored.
fn load_u16_at_src(byte_offset: u32) -> u32 {
// Index of the word that contains the requested bytes.
let word = src[byte_offset / 4u];
// 0 for the low half of the word, 16 for the high half.
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_u32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4;
let shift = (byte_offset & 0x3) * 8;
let lo = buf[word_idx];
let hi = buf[word_idx + 1];
let shifted = (lo >> shift) | (hi << (32 - shift));
return select(shifted, lo, shift == 0);
// Returns 32 bits read from `src` starting at `byte_offset`, which does not
// have to be a multiple of four: the value is assembled from the word that
// contains the offset and the word that follows it.
fn load_u32_at_src(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
// Bit position of the first requested byte inside `lo` (0, 8, 16 or 24).
let shift = (byte_offset & 0x3u) * 8u;
let lo = src[word_idx];
// The following word is always read, even when the offset is word-aligned.
let hi = src[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
// With shift == 0 the expression above would shift `hi` by 32 bits, so the
// aligned case returns `lo` directly instead of the combined value.
return select(shifted, lo, shift == 0u);
}
fn load_f16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
// Returns the half-precision float whose 16 bits are stored at `byte_offset`
// in `src`.
fn load_f16_at_src(byte_offset: u32) -> f16 {
// The loaded bits sit in the low half of the u32, so they decode into
// component 0 of the unpacked pair.
let packed = unpack2x16float(load_u16_at_src(byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
let d_bits = (word >> shift) & 0xFFFF;
// Reads the 16 bits stored at `byte_offset` in `src`, interprets them as a
// half-precision float and returns that value as f32.
fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
let word = src[byte_offset / 4u];
// Bit 1 of the offset selects the low (shift 0) or high (shift 16) half.
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
// The half-float bits occupy the low 16 bits, i.e. component 0 of the result.
return unpack2x16float(d_bits)[0];
}
#endif
#ifdef DECLARE_BYTE_LOADERS_SRC0
// Fetches the 16-bit value at `byte_offset` within `src0` and returns it in the
// low 16 bits of a u32. Bit 1 of the offset chooses which half of the
// containing word is used; bit 0 has no effect on the result.
fn load_u16_at_src0(byte_offset: u32) -> u32 {
    let word_index = byte_offset >> 2u;
    // 0 selects the low half of the word, 16 the high half.
    let bit_shift = (byte_offset & 2u) << 3u;
    return (src0[word_index] >> bit_shift) & 0xFFFFu;
}
// Fetches 32 bits from `src0` beginning at `byte_offset`. The offset may fall
// anywhere inside a word, in which case the result is stitched together from
// that word and the next one.
fn load_u32_at_src0(byte_offset: u32) -> u32 {
    let base = byte_offset >> 2u;
    // Bit position of the first requested byte inside the low word.
    let bit_shift = (byte_offset & 3u) << 3u;
    let low_word = src0[base];
    let high_word = src0[base + 1u];
    let stitched = (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    // A word-aligned offset needs no stitching: hand back the low word as is.
    return select(stitched, low_word, bit_shift == 0u);
}
// Decodes the 16 bits stored at `byte_offset` in `src0` as a half-precision
// float and returns it as f16.
fn load_f16_at_src0(byte_offset: u32) -> f16 {
    let half_bits = load_u16_at_src0(byte_offset);
    // The bits live in the low half of the u32, so the value is component x.
    return f16(unpack2x16float(half_bits).x);
}
// Decodes the 16 bits stored at `byte_offset` in `src0` as a half-precision
// float and returns the value widened to f32.
fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
    // Same half-word extraction as load_u16_at_src0: low 16 bits hold the value.
    let half_bits = load_u16_at_src0(byte_offset);
    return unpack2x16float(half_bits).x;
}
#endif
#endif
#ifdef Q4_1_T

View file

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC
#include "common_decls.tmpl"
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let qh_packed = load_u32_at(&src, block_byte_base + 2);
let d = load_f16_as_f32_at_src(block_byte_base);
let qh_packed = load_u32_at_src(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 8u; j++) {
let q_byte_offset = block_byte_base + 2u + j * 4u;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
let d = load_f16_as_f32_at_src(block_byte_base + 108);
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
scale_vals[0] = load_u32_at_src(block_byte_base + 96);
scale_vals[1] = load_u32_at_src(block_byte_base + 100);
scale_vals[2] = load_u32_at_src(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
}
var dst_i = dst_base + offset * 256;
@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
let d = load_f16_as_f32_at_src(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16u; i++) {
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
}
var dst_i = dst_base + offset * 256;
@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src, aux0_offset);
let aux1 = load_u32_at(&src, aux1_offset);
let aux0 = load_u32_at_src(aux0_offset);
let aux1 = load_u32_at_src(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
for (var ib: u32 = 0; ib < 32; ib += 4) {
@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
qh_vals[0] = load_u32_at_src(block_byte_base + 66);
qh_vals[1] = load_u32_at_src(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
scale_vals[0] = load_u32_at_src(block_byte_base + 74);
scale_vals[1] = load_u32_at_src(block_byte_base + 78);
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src, sc_sign_offset);
let sc_sign = load_u32_at_src(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src, block_byte_base + 106);
var scale_vals = load_u32_at_src(block_byte_base + 106);
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 32;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View file

@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef FLOAT
const BLOCK_SIZE = 1u;
@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
let qh_packed = load_u32_at_src0(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 8; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
}
var sum = 0.0;
@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
}
var sum = 0.0;
@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src0, aux0_offset);
let aux1 = load_u32_at(&src0, aux1_offset);
let aux0 = load_u32_at_src0(aux0_offset);
let aux1 = load_u32_at_src0(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sum = 0.0;
@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib ++) {
@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src0, sc_sign_offset);
let sc_sign = load_u32_at_src0(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
var scale_vals = load_u32_at_src0(block_byte_base + 106);
var sum = 0.0;
for (var ib: u32 = 0; ib < 4; ib++) {
@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 32;
var sum = 0.0;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View file

@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 80u);
let dmin = load_f16_at(&src0, block_byte_base + 82u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 108u);
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// The original loop processes elements in groups of 64
@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at(&src0, block_byte_base + 208u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View file

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC

View file

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC

View file

@ -3,7 +3,9 @@ enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
}
}

View file

@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef VEC
#define VEC_SIZE 4
@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
let d = f32(load_f16_at_src0(block_byte_base));
let m = f32(load_f16_at_src0(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at(&src0, bbase + 208u));
let d = f32(load_f16_at_src0(bbase + 208u));
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
let ql1_u32 = load_u32_at_src0(bbase + q_offset_l);
let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte);
let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View file

@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-9.010913, 9.010913)));
#endif
#ifdef XIELU
let val = f32(src[params.offset_src + src_idx]);
let res =
select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
src[params.offset_src + src_idx]) *
TYPE(params.alpha_n) +
TYPE(params.beta) * src[params.offset_src + src_idx],
TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] +
TYPE(params.beta) * src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
TYPE(select(
((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val,
params.alpha_p * val * val + params.beta * val,
val > 0.0));
#endif
#ifdef SOFTPLUS
let src_f32 = f32(src[params.offset_src + src_idx]);