mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 11:16:08 +00:00
fix(flash-attn): replace f32 with kv_type and q_type (#23372)
This commit is contained in:
parent
40d5358d3c
commit
5306f4b3b5
1 changed files with 22 additions and 22 deletions
|
|
@ -122,9 +122,9 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
|||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
|
||||
var<workgroup> q_shmem: array<f32, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<f32, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
|
||||
var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
|
@ -169,10 +169,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
let head = f32(head_idx);
|
||||
let slope = select(1.0,
|
||||
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
|
||||
pow(params.m0, head + 1.0),
|
||||
head < params.n_head_log2),
|
||||
params.max_bias > 0.0);
|
||||
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
|
||||
pow(params.m0, head + 1.0),
|
||||
head < params.n_head_log2),
|
||||
params.max_bias > 0.0);
|
||||
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let q_tile_row = elem_idx / HEAD_DIM_QK;
|
||||
|
|
@ -181,7 +181,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
|
||||
q_shmem[elem_idx] = select(
|
||||
0.0,
|
||||
f32(Q[global_q_row_offset + q_col]) * params.scale,
|
||||
Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale,
|
||||
head_q_row < params.seq_len_q);
|
||||
}
|
||||
|
||||
|
|
@ -213,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||
let k4 = K[k_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = f32(k4.x);
|
||||
kv_shmem[kv_off + 1u] = f32(k4.y);
|
||||
kv_shmem[kv_off + 2u] = f32(k4.z);
|
||||
kv_shmem[kv_off + 3u] = f32(k4.w);
|
||||
kv_shmem[kv_off + 0u] = KV_TYPE(k4.x);
|
||||
kv_shmem[kv_off + 1u] = KV_TYPE(k4.y);
|
||||
kv_shmem[kv_off + 2u] = KV_TYPE(k4.z);
|
||||
kv_shmem[kv_off + 3u] = KV_TYPE(k4.w);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -233,18 +233,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
var dot_val = 0.0;
|
||||
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
|
||||
let q_off = q_base + chunk * 4u;
|
||||
let qv = vec4<f32>(
|
||||
let qv = vec4<Q_TYPE>(
|
||||
q_shmem[q_off + 0u],
|
||||
q_shmem[q_off + 1u],
|
||||
q_shmem[q_off + 2u],
|
||||
q_shmem[q_off + 3u]);
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let kv = vec4<f32>(
|
||||
let kv = vec4<KV_TYPE>(
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
kv_shmem[kv_off + 3u]);
|
||||
dot_val += dot(qv, kv);
|
||||
dot_val += dot(vec4<f32>(qv), vec4<f32>(kv));
|
||||
}
|
||||
#ifdef LOGIT_SOFTCAP
|
||||
dot_val = params.logit_softcap * tanh(dot_val);
|
||||
|
|
@ -271,7 +271,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||
if (row_active && kv_local < kv_count) {
|
||||
let p = exp(local_scores[slot] - new_max);
|
||||
p_shmem[subgroup_p_offset + kv_local] = p;
|
||||
p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p);
|
||||
local_sum += p;
|
||||
}
|
||||
}
|
||||
|
|
@ -285,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||
let v4 = V[v_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = f32(v4.x);
|
||||
kv_shmem[kv_off + 1u] = f32(v4.y);
|
||||
kv_shmem[kv_off + 2u] = f32(v4.z);
|
||||
kv_shmem[kv_off + 3u] = f32(v4.w);
|
||||
kv_shmem[kv_off + 0u] = KV_TYPE(v4.x);
|
||||
kv_shmem[kv_off + 1u] = KV_TYPE(v4.y);
|
||||
kv_shmem[kv_off + 2u] = KV_TYPE(v4.z);
|
||||
kv_shmem[kv_off + 3u] = KV_TYPE(v4.w);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -308,12 +308,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
|
||||
let p = p_shmem[subgroup_p_offset + kv_local];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let v4 = vec4<f32>(
|
||||
let v4 = vec4<KV_TYPE>(
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
kv_shmem[kv_off + 3u]);
|
||||
acc += p * v4;
|
||||
acc += f32(p) * vec4<f32>(v4);
|
||||
}
|
||||
out_regs[reg_idx] = acc;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue