mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
* vulkan: Add bfloat16 support
  This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16. The extension is required for coopmat multiply support, but matrix-vector multiply trivially promotes bf16 to fp32 and doesn't require the extension. The copy/get_rows shaders also don't require the extension.
  It's probably possible to fall back to non-coopmat and promote to fp32 when the extension isn't supported, but this change doesn't do that. The coopmat support also requires a glslc that supports the extension, which currently requires a custom build.
* vulkan: Support bf16 tensors without the bf16 extension or coopmat support
  Compile a variant of the scalar mul_mm shader that promotes the bf16 values to float, and use it when either the bf16 extension or the coopmat extension isn't available.
* vulkan: bfloat16 fixes (really works without bfloat16 support now)
* vulkan: fix spirv-val failure and reenable -O
33 lines · 946 B · Text
#version 450
|
|
|
|
#include "types.comp"
|
|
#include "generic_binary_head.comp"
|
|
|
|
// One invocation per output element:
//   x -> i00, the element within a row (may exceed ne00 when the dispatch is
//        rounded up to the 512-wide workgroup, hence the check in main)
//   y -> i10, the index-row of B
//   z -> packed outer dims, z = i11*ne12 + i12
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// get_rows: for each index i01 = B[i10, i11, i12], copy element i00 of row i01
// of A (in batch i11, i12) into element i00 of row i10 of D, converting
// A's element type to D_TYPE via FLOAT_TYPE.
void main() {
    const uint i00 = gl_GlobalInvocationID.x;

    // Drop the padded tail of the x dimension.
    if (i00 >= p.ne00) {
        return;
    }

    // Loop over y and z with the full dispatch size as stride. With an exact
    // dispatch (ne10 in y, ne11*ne12 in z) each loop body executes once, as
    // before. The bounds also make any over-dispatch harmless, and they let the
    // host clamp the dispatch to the device's maxComputeWorkGroupCount without
    // skipping rows.
    const uint stride_y = gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    const uint stride_z = gl_WorkGroupSize.z * gl_NumWorkGroups.z;
    const uint n_z      = p.ne11 * p.ne12;

    for (uint gid_z = gl_GlobalInvocationID.z; gid_z < n_z; gid_z += stride_z) {
        const uint i11 = gid_z / p.ne12;
        const uint i12 = gid_z % p.ne12;

        for (uint i10 = gl_GlobalInvocationID.y; i10 < p.ne10; i10 += stride_y) {
            // Row of A to gather, read from the index tensor B.
            const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];

            const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
            const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

#if defined(DATA_A_BF16)
            // bf16 source values are widened through bf16_to_fp32 rather than a
            // native bf16 type, so this variant does not need the bf16 extension.
            const FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
            const FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
            // The former OPTIMIZATION_ERROR_WORKAROUND conditional had two
            // identical arms. One store now serves every variant, including
            // builds that still define that macro.
            data_d[d_offset + i00] = D_TYPE(v);
        }
    }
}
|