diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c deleted file mode 100644 index 2e84badc9..000000000 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ /dev/null @@ -1,955 +0,0 @@ -#include -#include -#include - -#include "hvx-utils.h" - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" - -#ifndef MIN -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#endif - -#define HTP_GDN_MAX_SV 128 - -struct htp_gdn_context { - struct htp_ops_context * octx; - uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; -}; - -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { - HVX_Vector acc = Q6_V_vzero(); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vmemu(dst + i * epv) = out; - acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); - acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); - } - - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); -} - -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { - HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vmul = hvx_vec_splat_f32(mul); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vmemu(dst + i * epv) = out; - acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); - acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); - } - - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); -} - -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { - HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vmemu(dst + i * epv) = out; - acc = 
hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); - acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); - } - - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); -} - -static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, const float * restrict mul, - const float * restrict dot, uint32_t n, float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = Q6_V_vzero(); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vm = hvx_vmem(mul + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 + i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - } - - HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; - hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); -} - -static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, float mul, - const float * restrict dot, uint32_t n, float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = 
Q6_V_vzero(); - const HVX_Vector vmul = hvx_vec_splat_f32(mul); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 + i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - } - - HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; - hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); -} - -static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, const float * restrict src, - const float * restrict scale, const float * restrict dot, uint32_t n, - float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = Q6_V_vzero(); - const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); - const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); - const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); - const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vs = hvx_vmem(src + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); - HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); - HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); - HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 
+ i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vs = hvx_vmem(src + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); - HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); - HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); - HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - } - - HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; - hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); -} - -static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, float * restrict dst4, - float * restrict dst5, float * restrict dst6, float * restrict dst7, - const float * restrict mul, const float * restrict dot, uint32_t n, - float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = Q6_V_vzero(); - HVX_Vector acc4 = Q6_V_vzero(); - HVX_Vector acc5 = Q6_V_vzero(); - HVX_Vector acc6 = Q6_V_vzero(); - HVX_Vector acc7 = Q6_V_vzero(); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vm = hvx_vmem(mul + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); - HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm); - HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm); - HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm); - HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 + i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - hvx_vmemu(dst4 + i * epv) = out4; - hvx_vmemu(dst5 + i * epv) = out5; - hvx_vmemu(dst6 + i * epv) = out6; - hvx_vmemu(dst7 + i * epv) = out7; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = 
hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); - acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); - acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); - acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm); - HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm); - HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); - HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); - acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); - acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); - acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); - } - - HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; - HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; - hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); - hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); -} - -static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, float * restrict dst4, - float * restrict dst5, float * restrict dst6, float * restrict dst7, - float mul, const float * restrict dot, uint32_t n, float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = Q6_V_vzero(); - HVX_Vector acc4 = Q6_V_vzero(); - HVX_Vector acc5 = Q6_V_vzero(); - HVX_Vector acc6 = Q6_V_vzero(); - HVX_Vector acc7 = Q6_V_vzero(); - const HVX_Vector vmul = hvx_vec_splat_f32(mul); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i 
< nvec; ++i) { - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); - HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul); - HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul); - HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul); - HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 + i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - hvx_vmemu(dst4 + i * epv) = out4; - hvx_vmemu(dst5 + i * epv) = out5; - hvx_vmemu(dst6 + i * epv) = out6; - hvx_vmemu(dst7 + i * epv) = out7; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); - acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); - acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); - acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); - HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); - HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); - HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul); - HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul); - HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); - HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); - acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); - acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); - acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); - } - - HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, 
acc3 } }; - HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; - hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); - hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); -} - -static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1, - float * restrict dst2, float * restrict dst3, float * restrict dst4, - float * restrict dst5, float * restrict dst6, float * restrict dst7, - const float * restrict src, const float * restrict scale, - const float * restrict dot, uint32_t n, float * restrict sums) { - HVX_Vector acc0 = Q6_V_vzero(); - HVX_Vector acc1 = Q6_V_vzero(); - HVX_Vector acc2 = Q6_V_vzero(); - HVX_Vector acc3 = Q6_V_vzero(); - HVX_Vector acc4 = Q6_V_vzero(); - HVX_Vector acc5 = Q6_V_vzero(); - HVX_Vector acc6 = Q6_V_vzero(); - HVX_Vector acc7 = Q6_V_vzero(); - const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); - const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); - const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); - const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); - const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]); - const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]); - const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]); - const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]); - - const uint32_t epv = 128 / sizeof(float); - const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; - for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vs = hvx_vmem(src + i * epv); - HVX_Vector vdot = hvx_vmem(dot + i * epv); - - HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); - HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); - HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); - HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); - HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4)); - HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5)); - HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6)); - HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7)); - - hvx_vmemu(dst0 + i * epv) = out0; - hvx_vmemu(dst1 + i * epv) = out1; - hvx_vmemu(dst2 + i * epv) = out2; - hvx_vmemu(dst3 + i * epv) = out3; - hvx_vmemu(dst4 + i * epv) = out4; - hvx_vmemu(dst5 + i * epv) = out5; - hvx_vmemu(dst6 + i * epv) = out6; - hvx_vmemu(dst7 + i * epv) = out7; - - acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); - acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); - acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); - acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); - acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); - acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); - acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); - acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); - } - - if (tail) { - const uint32_t off = nvec * epv; - HVX_Vector vs = hvx_vmem(src + off); - HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); - HVX_Vector zero = Q6_V_vzero(); - - HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); - 
HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); - HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); - HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4)); - HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5)); - HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); - HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); - - acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); - acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); - acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); - acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); - acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); - acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); - acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); - acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); - } - - HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; - HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; - hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); - hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); -} - -static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) { - struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; - struct htp_ops_context * octx = gctx->octx; - - const struct htp_tensor * q = octx->src[0]; - const struct htp_tensor * k = octx->src[1]; - const struct htp_tensor * v = octx->src[2]; - const struct htp_tensor * g = octx->src[3]; - const struct htp_tensor * beta = octx->src[4]; - const struct htp_tensor * state = octx->src[5]; - const struct htp_tensor * dst = octx->dst; - - const uint32_t S_v = v->ne[0]; - const uint32_t H = v->ne[1]; - const uint32_t n_tokens = v->ne[2]; - const uint32_t n_seqs = v->ne[3]; - - const uint32_t total_rows = H * n_seqs; - if (ith >= total_rows) { - return; - } - - const uint32_t rq3 = n_seqs / q->ne[3]; - const uint32_t rk3 = n_seqs / k->ne[3]; - const float scale = 1.0f / sqrtf((float) S_v); - - float * dst_base = (float *) (uintptr_t) dst->data; - float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs; - const float * state_in_base = (const float *) (uintptr_t) state->data; - - const bool kda = (g->ne[0] == S_v); - float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - 
float local_sums[4] __attribute__((aligned(128))); - - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; - - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; - - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; - - float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; - - for (uint32_t t = 0; t < n_tokens; ++t) { - const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + - (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]); - const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + - (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]); - const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + - (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]); - const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + - (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]); - const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + - (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); - - if (kda) { - hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); - - uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; - } - } else { - const float gate = expf(g_t[0]); - uint32_t j = 0; - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj 
= (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; - } - } - - attn_data += (uint64_t) S_v * H; - } - } -} - -static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { - struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; - struct htp_ops_context * octx = gctx->octx; - - const struct htp_tensor * q = octx->src[0]; - const struct htp_tensor * k = octx->src[1]; - const struct htp_tensor * v = octx->src[2]; - const struct htp_tensor * g = octx->src[3]; - const struct htp_tensor * beta = octx->src[4]; - const struct htp_tensor * state = octx->src[5]; - const struct htp_tensor * dst = octx->dst; - - const uint32_t S_v = v->ne[0]; - const uint32_t H = v->ne[1]; - const uint32_t n_seqs = v->ne[3]; - - const uint32_t total_rows = H * n_seqs; - if (ith >= total_rows) { - return; - } - - const uint32_t rq3 = n_seqs / q->ne[3]; - const uint32_t rk3 = n_seqs / k->ne[3]; - const float scale = 1.0f / sqrtf((float) S_v); - - float * dst_base = (float *) (uintptr_t) dst->data; - float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs; - const float * state_in_base = (const float *) (uintptr_t) state->data; - - const bool kda = (g->ne[0] == S_v); - float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); - - dma_queue * dma = octx->ctx->dma[ith]; - - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; - } - - for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; - - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; - - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - float * s_work; - - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } - - float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; - - const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + - (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]); - const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + - (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]); - const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + - (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]); - const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + - (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]); - const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + - (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); - - if (kda) { - hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); - - uint32_t j = 0; - for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; 
- float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; - gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, - local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, - local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; - } - } else { - const float gate = expf(g_t[0]); - uint32_t j = 0; - for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; - gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, - gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, - local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } - gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } - } - for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] 
= gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; - } - } - - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - } - } -} - -int op_gated_delta_net(struct htp_ops_context * octx) { - const struct htp_tensor * q = octx->src[0]; - const struct htp_tensor * k = octx->src[1]; - const struct htp_tensor * v = octx->src[2]; - const struct htp_tensor * g = octx->src[3]; - const struct htp_tensor * beta = octx->src[4]; - const struct htp_tensor * state = octx->src[5]; - const struct htp_tensor * dst = octx->dst; - - if (!q || !k || !v || !g || !beta || !state || !dst) { - return HTP_STATUS_INVAL_PARAMS; - } - - if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 || - g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 || - dst->type != HTP_TYPE_F32) { - return HTP_STATUS_NO_SUPPORT; - } - - const uint32_t S_v = v->ne[0]; - const uint32_t H = v->ne[1]; - const uint32_t n_tokens = v->ne[2]; - const uint32_t n_seqs = v->ne[3]; - - if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { - return HTP_STATUS_NO_SUPPORT; - } - if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { - return HTP_STATUS_NO_SUPPORT; - } - if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 || - q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 || - (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { - return HTP_STATUS_NO_SUPPORT; - } - if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { - return HTP_STATUS_NO_SUPPORT; - } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { - return HTP_STATUS_NO_SUPPORT; - } - - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { - return HTP_STATUS_OK; - } - - struct htp_gdn_context gctx; - gctx.octx = octx; - gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads; - gctx.state_bytes = (size_t) S_v * S_v * sizeof(float); - - size_t state_aligned = (size_t) S_v * S_v * sizeof(float); - state_aligned = (state_aligned + 127) & ~(size_t)127; - - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; - - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } - - if (n_tokens == 1) { - worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); - } else { - worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads); - } - - return HTP_STATUS_OK; -} diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c deleted file mode 100644 index 8a6d7c14e..000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ /dev/null @@ -1,1840 +0,0 @@ -// HMX-accelerated Flash Attention for prefill (neq1 >= 32). -// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. 
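For context on the gated-delta-net deletion above: per attention head and per token, those kernels compute the gated delta rule over an S_v x S_v state (decay the state, form the delta against v, apply a rank-1 update with k, then read the output out with q). Below is a minimal scalar sketch of that step, in plain C with assumed names, not the deleted HVX implementation; the non-KDA path uses a single scalar gate expf(g[0]) in place of the per-channel gate shown here.

    #include <math.h>

    // One token step for one head. S is the row-major [S_v x S_v] state,
    // q/k/v/g are [S_v], beta is a scalar, scale = 1.0f / sqrtf((float) S_v).
    static void gdn_token_step_ref(float * S, const float * q, const float * k,
                                   const float * v, const float * g,
                                   float beta, float scale, int S_v, float * out) {
        for (int j = 0; j < S_v; ++j) {
            float * row = S + j * S_v;
            float u = 0.0f;
            for (int i = 0; i < S_v; ++i) {
                row[i] *= expf(g[i]);     // decay (what gdn_mul_dot* applies)
                u      += row[i] * k[i];  // u_j = (S k)_j, fused with the decay pass
            }
            const float delta = (v[j] - u) * beta;
            float o = 0.0f;
            for (int i = 0; i < S_v; ++i) {
                row[i] += k[i] * delta;   // rank-1 delta update (gdn_add_scaled_dot*)
                o      += row[i] * q[i];  // out_j = (S q)_j, fused with the update pass
            }
            out[j] = o * scale;
        }
    }

The HVX helpers fuse the decay multiply with the S*k dot and the rank-1 update with the S*q dot in a single pass over each state row; the 4-row and 8-row variants amortize the loads of k, q and the gate across several state rows per call.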
- -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include -#include -#include -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "hex-dma.h" -#include "hmx-profile.h" -#include "hmx-queue.h" -#include "hmx-utils.h" -#include "htp-ctx.h" -#include "htp-ops.h" -#include "hvx-dump.h" -#include "hvx-reduce.h" -#include "hvx-utils.h" -#include "vtcm-utils.h" -#include "worker-pool.h" - -// ============================================================================ -// Constants -// ============================================================================ - -// Tile constants from hmx-utils.h -// HMX_FP16_TILE_N_ROWS = 32 -// HMX_FP16_TILE_N_COLS = 32 -// HMX_FP16_TILE_N_ELMS = 1024 -// HMX_FP16_TILE_SIZE = 2048 - -// ============================================================================ -// Dynamic block size computation (GQA-aware) -// ============================================================================ - -// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. -// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. -// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales -// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { - const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); - const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] - const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong - const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf - const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf - const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved - const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved - const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] - const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] - const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. 
- const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); - const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); - const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); - const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); - - return q_tile_size * 1 // Q tiles - + o_tile_size * 2 // O ping-pong - + k_dma_size * 2 // K DMA x2 - + v_dma_size * 2 // V DMA x2 - + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles - + s_tile_size * 2 // S + P - + d_tile_size * 1 // D (diagonal matrix) - + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum - + row_vec_size * 2 * n_threads // per-thread softmax row scratch - + m_buf_size * 1 // mask VTCM buffer [Br rows] - + slopes_size // Slopes - + 256 * 2; // HMX scales (id + qk) -} - -// ============================================================================ -// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) -// ============================================================================ -// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be -// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per -// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). -// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). -static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { - const HVX_Vector zero_v = Q6_V_vzero(); - const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 - - // k = round_toward_neg_inf(x); f = (float)k; frac = x - f - HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); - HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 - HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 - - HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 - - // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 - HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 - y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 - y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 - y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 - y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 - y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x - y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 - - // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k - y = Q6_Vhf_equals_Vqf16(y); - HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); - y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); - HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); - y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); - return Q6_V_vmux_QVV(q_underflow, zero_v, y); -} - -#define FA_MIN_KV_BLOCKS 3 - -// Cost-based (Br, Bc) search for flash attention with pipeline constraint. 
-// -// VTCM model (same as before): -// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc -// -// Cost model (minimization objective): -// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) -static int hmx_fa_find_chunk_size(size_t * Br_out, - size_t * Bc_out, - size_t gqa_factor, - size_t DK, - size_t DV, - size_t qo_len, - size_t kv_len, - size_t vtcm_budget, - size_t n_threads) { - const size_t T = HMX_FP16_TILE_N_ROWS; // 32 - const size_t br_unit = hmx_ceil_div(T, gqa_factor); - // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax - // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) - // that would race across r0 blocks if the last dual-tile is half-occupied. - // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. - const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 - const size_t fp16 = sizeof(__fp16); - - // Approximate per-unit VTCM costs (without per-buffer alignment padding). - const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors - const size_t per_gbr2 = fp16; // D diagonal matrix - const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs - const size_t per_gbr_bc = 2 * fp16; // S + P - - const size_t overhead = 256 * 2 + 13 * 4096; - - if (vtcm_budget <= overhead) { - return -1; - } - const size_t usable = vtcm_budget - overhead; - - // Br_max: largest Br aligned to br_unit that does not exceed qo_len. - const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; - - // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. - // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); - const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : - (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); - // Cost coefficients calibrated from profiling - const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store - const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers - - size_t best_cost = SIZE_MAX, best_mn = 0; - size_t best_Br = 0, best_Bc = 0; - - for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { - const size_t g_br = hex_align_up(gqa_factor * Br, T); - - // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 - const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; - if (gbr_cost >= usable) { - if (Br == br_unit) { - break; - } - continue; - } - - // Analytically solve for max Bc: - // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) - // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. 
- const size_t remain = usable - gbr_cost; - const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; - size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); - if (Bc < bc_unit) { - if (Br == br_unit) { - break; - } - continue; - } - - // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { - Bc -= bc_unit; - } - if (Bc < bc_unit) { - if (Br == br_unit) { - break; - } - continue; - } - - const size_t q_blocks = (qo_len + Br - 1) / Br; - const size_t kv_blocks = (kv_len + Bc - 1) / Bc; - const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); - const size_t mn = Br * Bc; - - if (cost < best_cost || (cost == best_cost && mn > best_mn)) { - best_cost = cost; - best_mn = mn; - best_Br = Br; - best_Bc = Bc; - } - - if (Br == br_unit) { - break; - } - } - - if (best_Br == 0) { - return -1; - } - - *Br_out = best_Br; - *Bc_out = best_Bc; - return 0; -} - -// ============================================================================ -// Tile interleave / extract helpers -// ============================================================================ - -// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets - -// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 -// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile -static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { - 0 * 136, 0 * 136 + 6, - 1 * 136, 1 * 136 + 6, - 2 * 136, 2 * 136 + 6, - 3 * 136, 3 * 136 + 6, - 4 * 136, 4 * 136 + 6, - 5 * 136, 5 * 136 + 6, - 6 * 136, 6 * 136 + 6, - 7 * 136, 7 * 136 + 6, - 8 * 136, 8 * 136 + 6, - 9 * 136, 9 * 136 + 6, - 10 * 136, 10 * 136 + 6, - 11 * 136, 11 * 136 + 6, - 12 * 136, 12 * 136 + 6, - 13 * 136, 13 * 136 + 6, - 14 * 136, 14 * 136 + 6, - 15 * 136, 15 * 136 + 6, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, - 0, 0, -}; - -// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h - -// ============================================================================ -// HMX Flash Attention context (GQA-merged) -// ============================================================================ - -struct hmx_fa_context { - const struct htp_ops_context * octx; - bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 - uint32_t n_threads; - - // Op parameters - float scale; - float max_bias; - float logit_softcap; - uint32_t n_head_log2; - float m0, m1; - - // Dimensions - uint32_t DK, DV; - uint32_t n_kv; // kv_len - uint32_t n_kv_heads; // number of KV heads - uint32_t n_heads; // number of Q heads - uint32_t G; // GQA factor = n_heads / n_kv_heads - uint32_t n_kv_blocks; - uint32_t neq1; // Q token count - - // Types - bool is_q_fp32; - bool is_dst_fp32; - - // Dynamic block sizes - uint32_t Br; // Q tokens per block (before GQA expansion) - uint32_t Bc; - uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim - - // VTCM buffers (allocated by vtcm_seq_alloc) - __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] - __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] - __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] - __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] - __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) - __fp16 * vtcm_s_tiles; // S = QK^T 
[g_br, Bc] - __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] - __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] - HVX_Vector * vtcm_m_vec; // Row max [g_br] - HVX_Vector * vtcm_l_vec; // Row sum [g_br] - HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] - HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] - HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] - uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) - uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) - __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block - __fp16 * vtcm_slopes; // ALiBi slopes [g_br] - size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) - size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer - bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) -}; - -// ============================================================================ -// Multi-thread K interleave phase -// ============================================================================ - -typedef struct { - struct hmx_fa_context * factx; - int kv_rows; - size_t src_stride; - size_t buf_idx; -} fa_k_int_args_t; - -static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { - fa_k_int_args_t * args = (fa_k_int_args_t *) data; - struct hmx_fa_context * factx = args->factx; - - const int total_rows = args->kv_rows; - const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) - const int start = i * rows_per_t; - const int end = hex_smin(start + rows_per_t, total_rows); - - if (start >= total_rows) { - return; - } - - hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, - (int) args->src_stride, start, end); -} - -static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { - worker_pool_context_t wp = factx->octx->ctx->worker_pool; - fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; - if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { - worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); - } else { - fa_k_interleave_thread(1, 0, &args); - } -} - -// ============================================================================ -// Multi-thread V interleave phase -// ============================================================================ - -typedef struct { - struct hmx_fa_context * factx; - int kv_rows; - size_t src_stride; - size_t buf_idx; - size_t n_col_tiles; -} fa_v_int_args_t; - -static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { - fa_v_int_args_t * args = (fa_v_int_args_t *) data; - struct hmx_fa_context * factx = args->factx; - - const int total_rows = args->kv_rows; - const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); - const int start = i * rows_per_t; - const int end = hex_smin(start + rows_per_t, total_rows); - - if (start >= total_rows) { - return; - } - - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, - (int) args->src_stride, (int) args->n_col_tiles, start, end); -} - -static void fa_phase_v_interleave(struct hmx_fa_context * factx, - int kv_rows, - size_t src_stride, - size_t buf_idx, - size_t n_col_tiles) { - worker_pool_context_t wp = factx->octx->ctx->worker_pool; - fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, 
n_col_tiles }; - if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { - worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); - } else { - fa_v_interleave_thread(1, 0, &args); - } -} - -// ============================================================================ -// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 -// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. -// Each thread owns a disjoint range of row pairs; writes target distinct tile -// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no -// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded -// by the caller before dispatching. -// ============================================================================ - -typedef struct { - struct hmx_fa_context * factx; - const struct htp_tensor * q; - uint32_t q_start; - uint32_t kv_head; - uint32_t ib3; - size_t n_rows_g; -} fa_q_load_args_t; - -static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { - fa_q_load_args_t * args = (fa_q_load_args_t *) data; - struct hmx_fa_context * factx = args->factx; - - const size_t n_rows_g = args->n_rows_g; - const size_t G = factx->G; - const size_t DK = factx->DK; - - // Partition row pairs across threads. Keep each thread's start even so r/r+1 - // are always in the same thread's range. - const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); - const size_t start = (size_t) i * rows_per_t; - const size_t end = hex_smin(start + rows_per_t, n_rows_g); - - if (start >= n_rows_g) { - return; - } - - const struct htp_tensor * q = args->q; - const uint32_t q_start = args->q_start; - const uint32_t kv_head = args->kv_head; - const uint32_t ib3 = args->ib3; - - for (size_t r = start; r < end; r += 2) { - const bool next_row_valid = (r + 1) < n_rows_g; - - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; - - const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + - (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; - const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + - (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : - NULL; - - size_t r0 = r / HMX_FP16_TILE_N_ROWS; - size_t r1 = r % HMX_FP16_TILE_N_ROWS; - __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; - - if (factx->is_q_fp32) { - const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; - const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; - - for (uint32_t d = 0; d < DK / 32; ++d) { - HVX_Vector v0 = pv_in0[d]; - HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); - HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); - - HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); - out_tile[r1 / 2] = v_hf; - } - } else { - const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; - const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; - - for (uint32_t d = 0; d < DK / 64; ++d) { - HVX_Vector v0 = pv_in0[d]; - HVX_Vector v1 = pv_in1 ? 
pv_in1[d] : Q6_V_vzero(); - HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); - - __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; - HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; - HVX_Vector * pv_out1 = pv_out0 + 16; - - *pv_out0 = Q6_V_lo_W(vp); - *pv_out1 = Q6_V_hi_W(vp); - } - } - } -} - -static void fa_phase_q_load(struct hmx_fa_context * factx, - const struct htp_tensor * q, - uint32_t q_start, - uint32_t kv_head, - uint32_t ib3, - size_t n_rows_g) { - worker_pool_context_t wp = factx->octx->ctx->worker_pool; - fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; - // Require >= 2 row pairs per thread so partitioning is worthwhile. - if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { - worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); - } else { - fa_q_load_thread(1, 0, &args); - } -} - -// ============================================================================ -// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or -// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a -// disjoint row range; writes target distinct dst rows (different q_idx/h_idx -// pairs produced by r/G and r%G), so there is no write conflict. -// ============================================================================ - -typedef struct { - struct hmx_fa_context * factx; - const struct htp_tensor * dst; - const __fp16 * o_tile_src; - uint32_t q_start; - uint32_t kv_head; - uint32_t ib3; - size_t n_rows_g; -} fa_o_store_args_t; - -static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { - fa_o_store_args_t * args = (fa_o_store_args_t *) data; - struct hmx_fa_context * factx = args->factx; - - const size_t n_rows_g = args->n_rows_g; - const size_t G = factx->G; - const size_t DV = factx->DV; - - const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); - const size_t start = (size_t) i * rows_per_t; - const size_t end = hex_smin(start + rows_per_t, n_rows_g); - - if (start >= n_rows_g) { - return; - } - - const struct htp_tensor * dst = args->dst; - const __fp16 * o_tile_src = args->o_tile_src; - const uint32_t q_start = args->q_start; - const uint32_t kv_head = args->kv_head; - const uint32_t ib3 = args->ib3; - - for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; - - // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> - // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. 
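/* Illustrative address math for the permuted dst (a sketch with hypothetical
 * numbers, not from the original file): with G = 4, GQA-merged row r = 10
 * maps to q_idx = 10 / 4 = 2 and h_idx = 10 % 4 = 2, so the DDR row is
 *
 *     dst->data + (kv_head * 4 + 2) * dst->nb[1]   // head stride
 *               + (q_start + 2)     * dst->nb[2]   // token stride
 *               + ib3               * dst->nb[3];  // sequence stride
 *
 * i.e. heads advance along nb[1] and tokens along nb[2], matching the
 * permute(0,2,1,3) layout described above. */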
- uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + - (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; - - size_t r0 = r / HMX_FP16_TILE_N_ROWS; - size_t r1 = r % HMX_FP16_TILE_N_ROWS; - const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; - - if (factx->is_dst_fp32) { - float * out = (float *) dst_row; - for (uint32_t d = 0; d < DV / 32; ++d) { - const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); - HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); - if (r1 % 2 == 0) { - *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); - } else { - *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); - } - } - } else { - __fp16 * out = (__fp16 *) dst_row; - for (uint32_t d = 0; d < DV / 64; ++d) { - const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; - const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; - const HVX_Vector * pv_in1 = pv_in0 + 16; - HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); - if (r1 % 2 == 0) { - *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); - } else { - *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); - } - } - } - } -} - -static void fa_phase_o_store(struct hmx_fa_context * factx, - const struct htp_tensor * dst, - const __fp16 * o_tile_src, - uint32_t q_start, - uint32_t kv_head, - uint32_t ib3, - size_t n_rows_g) { - worker_pool_context_t wp = factx->octx->ctx->worker_pool; - fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; - if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { - worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); - } else { - fa_o_store_thread(1, 0, &args); - } -} - -// ============================================================================ -// Multi-thread softmax phase + serial m/l update + build_D -// ============================================================================ - -typedef struct { - struct hmx_fa_context * factx; - size_t kv_rows; - size_t n_rows_g; - size_t n_col_tiles; - size_t n_tiles_per_bc; - size_t n_row_tiles; - size_t n_row_tiles_g_br; - uint32_t Bc; - uint32_t G; - uint32_t kv_head; - uint32_t kv_start; - uint32_t q_start; - uint32_t ib3; - bool has_alibi; // true when max_bias != 0 (need slope * mask + add) - - // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) - // slope[r] = 1.0 when max_bias == 0 (no ALiBi) - // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) - __fp16 * slopes; - - // Mask info (preloaded before softmax) - const struct htp_tensor * mask; - const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) - size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer -} fa_softmax_args_t; - -static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { - fa_softmax_args_t * args = (fa_softmax_args_t *) data; - struct hmx_fa_context * factx = args->factx; - - const size_t n_rows_g = args->n_rows_g; - const size_t kv_rows = args->kv_rows; - const size_t Bc = args->Bc; - const size_t G = args->G; - const size_t n_tiles_per_bc = args->n_tiles_per_bc; - const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); - - // Partition r_vec_idx across threads - const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); - const size_t vec_start = i * vecs_per_t; - const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); - - if (vec_start >= n_row_vec_cnt) { - return; - } - - // Per-thread row 
scratch: thread i uses bufs at offset i * 2 * stride - const size_t row_buf_stride = factx->row_buf_stride; - HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; - HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; - - const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); - - // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. - // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different - // bit layout, so a later hf-domain read would silently produce wrong values. - // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax - // are hf; vtcm_l_vec is qf16 — don't mix them up. - - for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { - HVX_Vector rowmax_acc_v = v_neg_inf; - HVX_Vector rowsum_acc_v = Q6_V_vzero(); - HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; - - for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { - int r = r_vec_idx * 64 + r_vec_off; - if (r >= (int) hex_align_up(n_rows_g, 2)) { - break; - } - - int r0 = r / HMX_FP16_TILE_N_ROWS; - int r1 = r % HMX_FP16_TILE_N_ROWS; - - const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; - __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; - - // Decode 2 rows from S tiles into per-thread row buffers - HVX_Vector * pv_row_buf0 = my_row_buf0; - HVX_Vector * pv_row_buf1 = my_row_buf1; - for (size_t c = 0; c < kv_rows; c += 64) { - const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; - const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; - const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; - - HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); - *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); - *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); - } - - // Apply softcap if enabled (in F32 precision) - if (factx->logit_softcap != 0.0f) { - // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in - // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT - // of qk_scale in this configuration (see scale setup) so tanh sees - // the physical QK/(√d·c) argument. - float cap = factx->logit_softcap; -#ifdef HMX_FA_USE_EXP2_HF - cap *= 1.44269504f; // log2(e) -#endif - const HVX_Vector v_cap = hvx_vec_splat_f32(cap); - for (size_t c = 0; c < kv_rows; c += 64) { - size_t ci = c / 64; - - HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); - HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); - HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); - t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); - t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); - my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); - - HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); - HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); - HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); - t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); - t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); - my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); - } - } - - // Apply mask & compute rowmax(S) - // - // Optimizations over baseline: - // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the - // slope multiplication — still add mask (additive bias) but - // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. - // B. GQA mask row dedup: G consecutive Q rows share one mask row - // (qi = r / G). 
Reuse mask vector when qi is unchanged between - // row0 and row1 (saves ~75% of VTCM loads for G=4). - - // ALiBi slopes — only needed when has_alibi (scheme A) - HVX_Vector v_slope0, v_slope1; - if (args->has_alibi) { - v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); - v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); - } - - const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) - - HVX_Vector v_s_rowmax0 = v_neg_inf; - HVX_Vector v_s_rowmax1 = v_neg_inf; - for (size_t c = 0; c < kv_rows; c += 64) { - size_t ci = c / 64; - const size_t ne = hex_smin(kv_rows - c, 64); - HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); - - if (args->mask) { - HVX_Vector v_mask0, v_mask1; - - if (args->mask_vtcm) { - // Read mask from VTCM buffer (DMA'd per KV block). - // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; - v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); - v_mask1 = v_neg_inf; - if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; - if (qi1 == qi0) { - v_mask1 = v_mask0; // scheme B: reuse — same mask row - } else { - v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); - } - } - } else { - // Fallback: read mask directly from DDR (when mask->ne[2] > 1). - const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; - const uint32_t im2_0 = h_idx0 % mask->ne[2]; - const uint32_t im3_0 = args->ib3 % mask->ne[3]; - - const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + - im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; - v_mask0 = *(const HVX_UVector *) m0_ptr; - v_mask1 = v_neg_inf; - - if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); - if (q_idx1 == q_idx0) { - // scheme B: same mask row in DDR path - v_mask1 = v_mask0; - } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; - const uint32_t im2_1 = h_idx1 % mask->ne[2]; - const uint32_t im3_1 = args->ib3 % mask->ne[3]; - const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + - im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; - v_mask1 = *(const HVX_UVector *) m1_ptr; - } - } - } - - // Threshold: mask values below -16.0 are treated as -inf (causal mask). - HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); - HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); - - if (args->has_alibi) { - // ALiBi path: S += slope * mask (full mul + add) - HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); - HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); - my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); - my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); - } else { - // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul - // but still add mask (additive positional bias). vmux - // clamps mask < -16 to -inf as a numerical safeguard. 
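/* Scalar model of the masked rowmax step below (an illustrative sketch, not
 * part of the original file):
 *
 *     for (c = 0; c < kv_rows; ++c) {
 *         float s = S[r][c] + slope * mask[c];   // slope == 1.0 on the fast path
 *         S[r][c] = (mask[c] > -16.0f) ? s : -INFINITY;
 *         rowmax  = fmaxf(rowmax, S[r][c]);
 *     }
 *
 * Mask entries at or below the -16.0 threshold force the lane to -inf, so the
 * later exp() yields exactly 0 for masked-out positions. */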
-                        my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf);
-                        my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf);
-                    }
-                } else {
-                    if (ne < 64) {
-                        my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf);
-                        my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf);
-                    }
-                }
-
-                v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]);
-                v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]);
-            }
-
-            v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0);
-            v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1);
-
-            // Splat m_prev[r], m_prev[r+1] from the per-row accumulator.
-            // vror brings the target lane to lane 0, then extract + re-splat.
-            HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)));
-            HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)));
-
-            // HVX max — both operands are splats, so result is splat of m_new.
-            HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0);
-            HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1);
-
-            // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux.
-            // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2.
-            // vsetq2 handles the n=128 corner case when r_vec_off reaches 62.
-            {
-                HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2);
-                HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2);
-                HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2);
-                HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start);
-                HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid);
-                rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v);
-                rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v);
-            }
-
-            // Compute P = exp(S - m_new), using HVX exp
-            const HVX_Vector v_zero = Q6_V_vzero();
-            HVX_Vector v_p_rowsum0 = v_zero;
-            HVX_Vector v_p_rowsum1 = v_zero;
-
-#ifdef HMX_FA_USE_EXP2_HF
-            // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c):
-            // P = exp2(S - m_new)
-            for (size_t c = 0; c < kv_rows; c += 64) {
-                size_t ci = c / 64;
-                HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0);
-                HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1);
-
-                HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0));
-                HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1));
-#else
-            // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision,
-            // at the cost of extra f16/f32 conversions per 64-column chunk.
-            for (size_t c = 0; c < kv_rows; c += 64) {
-                size_t ci = c / 64;
-                HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0);
-                HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1);
-
-                HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0));
-                HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0));
-                HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0));
-                HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi);
-
-                HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1));
-                HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1));
-                HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1));
-                HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi);
-#endif
-                // Write P to tile format. Dual-tile pattern assumes Bc is a
-                // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size),
-                // so both tile halves are always in the current r0 block.
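/* Dual-tile layout sketch: Q6_W_vshuff_VVR(v1, v0, -2) interleaves the two
 * rows element-pairwise, so the lo half carries columns 0..31 of rows r and
 * r+1 (the row-pair slot of the first 32x32 tile) and the hi half carries
 * columns 32..63 (the adjacent tile). One fp16 tile is 32 * 32 * 2 = 2048
 * bytes = 16 HVX vectors, which is why pv_p_out1 sits exactly 16 vectors
 * past pv_p_out0. */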
- __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; - HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; - HVX_Vector * pv_p_out1 = pv_p_out0 + 16; - - HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); - *pv_p_out0 = Q6_V_lo_W(vp_p_dual); - *pv_p_out1 = Q6_V_hi_W(vp_p_dual); - - HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); - HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); - - v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); - v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); - } - - HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); - HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); - { - // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. - HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); - HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); - - HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); - HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); - HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); - HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); - HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); - rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); - rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); - } - } - - factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; - factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; - } -} - -// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). -// -// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter -// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection -// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler -// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread -// reads stale VTCM (PPL → ~vocab-size). -static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, - size_t n_rows_g, - size_t n_row_tiles, - size_t n_row_tiles_g_br) { - // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete - HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; - - const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); - for (size_t i = 0; i < n_row_vec_cnt; ++i) { - HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; - HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); - HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); - -#ifdef HMX_FA_USE_EXP2_HF - // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. 
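/* Why the base-2 path is equivalent (sketch): qk_scale is pre-multiplied by
 * log2(e) (or, with softcap, log2(e) is folded into the post-tanh multiplier),
 * so the S tiles hold S' = log2(e) * S. Since max() commutes with a positive
 * scale, m' = log2(e) * m, and
 *
 *     exp2(S' - m') = exp2(log2(e) * (S - m)) = exp(S - m),
 *
 * so P and the exp(m_diff) rescale reproduce ggml's base-e softmax exactly,
 * just computed in the scaled domain. */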
- HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); -#else - HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); - HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); - HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); - HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); -#endif - - HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); - v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); - - factx->vtcm_m_vec[i] = v_m_curr; - factx->vtcm_l_vec[i] = v_l_curr; - mvec_exp_m_diff[i] = v_exp_m_diff; - } - - // Build diagonal tile D = diag(exp(m_diff)) - const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; - const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); - for (size_t i = 0; i < n_row_tiles; ++i) { - const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); - __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); - // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the - // compiler may not recognize the volatile read below as aliasing and - // could reorder it before the scatter, defeating the HW drain. - __asm__ __volatile__("" ::: "memory"); - // Per-tile drain: scatter regions are disjoint (stride > tile size), - // so a single drain at tile 0 does NOT retire later tiles' entries. - (void) *(volatile HVX_Vector *) out_base; - } -} - -// Build D = diag(1/l) tile for the final O = D @ O normalization. -// -// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from -// being hoisted past the subsequent hmx_queue_push at the o_norm call site. -static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, - size_t n_row_tiles, - size_t n_row_tiles_g_br) { - const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; - const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); - const HVX_Vector one = hvx_vec_splat_f32(1.0f); - - HVX_Vector v_content = Q6_V_vzero(); - for (size_t i = 0; i < n_row_tiles; ++i) { - if ((i % 2) == 0) { - HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); - HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); - HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); - HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); - v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); - } else { - v_content = Q6_V_vror_VR(v_content, 64); - } - - __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); - // Compiler barrier — see fa_ml_update_and_build_d for rationale. 
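/* Scalar model of the normalization this tile feeds (illustrative sketch):
 * D = diag(1.0f / l[r]), so the HMX product O = D @ O_prev amounts to
 *
 *     for (r = 0; r < n_rows; ++r)
 *         for (c = 0; c < DV; ++c)
 *             O_final[r][c] = O_prev[r][c] / l[r];
 *
 * dividing every output row by its accumulated softmax denominator. */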
- __asm__ __volatile__("" ::: "memory"); - (void) *(volatile HVX_Vector *) out_base; - } -} - -// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D -static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, - fa_softmax_args_t * sargs, - size_t n_row_tiles, - size_t n_row_tiles_g_br) { - worker_pool_context_t wp = factx->octx->ctx->worker_pool; - const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); - - if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { - uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); - worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); - } else { - fa_softmax_thread(1, 0, sargs); - } - // barrier implicit in worker_pool_run_func return - - fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); -} - -// ============================================================================ -// HMX job structs and worker functions -// ============================================================================ - -typedef struct { - const __fp16 * q_tiles; - const __fp16 * k_tiles; - __fp16 * s_tiles; - size_t n_row_tiles; - size_t n_col_tiles; - size_t n_dot_tiles; // DK / 32 - size_t n_tiles_per_bc; - uint8_t * hmx_scales; -} hmx_fa_qk_job_t; - -static void hmx_fa_qk_dot_worker(void * data) { - hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; - const size_t n_row_tiles = job->n_row_tiles; - const size_t n_col_tiles = job->n_col_tiles; - const size_t n_dot_tiles = job->n_dot_tiles; - const size_t n_tiles_per_bc = job->n_tiles_per_bc; - const __fp16 * restrict q_tiles = job->q_tiles; - const __fp16 * restrict k_tiles = job->k_tiles; - __fp16 * restrict s_tiles = job->s_tiles; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *) job->hmx_scales); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < n_col_tiles; ++c) { - const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; - const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; - __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; - - for (size_t k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(out_tile, 0); - } - } -} - -typedef struct { - __fp16 * o_curr; - const __fp16 * o_prev; - const __fp16 * p_tiles; - const __fp16 * v_tiles; - const __fp16 * d_tiles; - uint8_t * hmx_scales; - size_t n_row_tiles; - size_t n_col_tiles; - size_t n_row_tiles_g_br; - size_t n_tiles_per_bc; - size_t DV; -} hmx_fa_o_update_job_t; - -static void hmx_fa_o_update_worker(void * data) { - hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; - const size_t n_row_tiles = job->n_row_tiles; - const size_t n_col_tiles = job->n_col_tiles; - const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; - const size_t n_tiles_per_bc = job->n_tiles_per_bc; - const size_t DV_tiles = job->DV / 32; - const __fp16 * restrict d_tiles = job->d_tiles; - const __fp16 * restrict p_tiles = job->p_tiles; - const __fp16 * restrict v_tiles = job->v_tiles; - const __fp16 * restrict o_prev = job->o_prev; - __fp16 * restrict o_curr = job->o_curr; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(DV_tiles > 
0); - - Q6_bias_mxmem2_A((void *) job->hmx_scales); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < DV_tiles; ++c) { - // D[r,r] @ O_prev[r,c] — only the diagonal tile - const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); - - // P @ V (accumulate on same accumulator) - const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; - const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; - for (size_t k = 0; k < n_col_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); - p_tile_in += HMX_FP16_TILE_N_ELMS; - v_tile_in += HMX_FP16_TILE_N_ELMS; - } - - __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - Q6_mxmem_AR_after_hf(o_tile_out, 0); - } - } -} - -typedef struct { - __fp16 * o_curr; // output (row-major tile layout) - const __fp16 * o_prev; // input (column-major tile layout) - const __fp16 * d_tiles; // diag(1/l) tiles - uint8_t * hmx_scales; - size_t n_row_tiles; - size_t n_row_tiles_g_br; - size_t DV; -} hmx_fa_o_norm_job_t; - -static void hmx_fa_o_norm_worker(void * data) { - hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; - const size_t n_row_tiles = job->n_row_tiles; - const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; - const size_t DV_tiles = job->DV / 32; - const __fp16 * restrict d_tiles = job->d_tiles; - const __fp16 * restrict o_prev = job->o_prev; - __fp16 * restrict o_curr = job->o_curr; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(DV_tiles > 0); - - Q6_bias_mxmem2_A((void *) job->hmx_scales); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < DV_tiles; ++c) { - const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; - - Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); - Q6_mxmem_AR_after_hf(o_out, 0); - } - } -} - -// Populate per-GQA-row ALiBi slopes for a given KV head. -// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. -// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). -// When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, - const struct hmx_fa_context * factx, - uint32_t kv_head, - size_t n_rows_g) { - if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } - return; - } - - const uint32_t G = factx->G; - const uint32_t n_head_log2 = factx->n_head_log2; - const float m0 = factx->m0; - const float m1 = factx->m1; - - for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? 
powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); - } -} - -// ============================================================================ -// Core HMX flash attention algorithm (GQA-merged) -// ============================================================================ - -int hmx_flash_attn_ext(struct htp_ops_context * octx) { - const struct htp_tensor * q = octx->src[0]; - const struct htp_tensor * k = octx->src[1]; - const struct htp_tensor * v = octx->src[2]; - const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; - const struct htp_tensor * dst = octx->dst; - - struct htp_context * const ctx = octx->ctx; - - if (!ctx->hmx_enabled) { - return HTP_STATUS_NO_SUPPORT; - } - - // Dimensions - const uint32_t neq0 = q->ne[0]; // head_dim (DK) - const uint32_t neq1 = q->ne[1]; // n_tokens - const uint32_t neq2 = q->ne[2]; // n_heads - const uint32_t neq3 = q->ne[3]; // n_seqs - - const uint32_t nek0 = k->ne[0]; // head_dim - const uint32_t nek1 = k->ne[1]; // kv_len - - const uint32_t nev0 = v->ne[0]; // head_dim (DV) - - const uint32_t DK = neq0; - const uint32_t DV = nev0; - - // HMX requires head_dim to be multiple of 32 - if (DK % 32 != 0 || DV % 32 != 0) { - return HTP_STATUS_NO_SUPPORT; - } - if (neq1 < 32) { - return HTP_STATUS_NO_SUPPORT; - } - - // GQA factor - const uint32_t n_kv_heads = k->ne[2]; - const uint32_t G = neq2 / n_kv_heads; - - // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; - - // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) - size_t Br, Bc; - const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { - return HTP_STATUS_VTCM_TOO_SMALL; - } - - const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); - - const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); - - FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", - neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); - - // ======== Build context ======== - struct hmx_fa_context factx; - memset(&factx, 0, sizeof(factx)); - factx.octx = octx; - factx.n_threads = octx->ctx->n_threads; - factx.DK = DK; - factx.DV = DV; - factx.n_kv = nek1; - factx.n_kv_heads = n_kv_heads; - factx.n_heads = neq2; - factx.G = G; - factx.neq1 = neq1; - factx.Br = (uint32_t) Br; - factx.Bc = (uint32_t) Bc; - factx.g_br = (uint32_t) g_br; - factx.n_kv_blocks = n_kv_blocks; - factx.is_q_fp32 = (q->type == HTP_TYPE_F32); - factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); - factx.use_pipeline = use_pipeline; - factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); - - // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) - float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; - memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - -#ifdef HMX_FA_USE_EXP2_HF - // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled - // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), - // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. 
- // - // When softcap is active we cannot pre-bake log2(e) here — it would land inside - // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving - // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh - // multiplier (see softcap block: v_cap absorbs log2(e)). - if (logit_softcap == 0.0f) { - scale *= 1.44269504f; // log2(e) - } -#endif - - factx.scale = scale; - factx.max_bias = max_bias; - factx.logit_softcap = logit_softcap; - - factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); - factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); - factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); - - // ======== VTCM allocation (GQA-aware) ======== - const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); - const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); - const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); - const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); - const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); - const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); - const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); - const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); - const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); - const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); - const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); - - uint8_t * vtcm_cur = ctx->vtcm_base; - - factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); - factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); - factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); - factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); - factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); - factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); - factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); - factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); - factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); - factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); - factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); - factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); - factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); - factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); - factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); - factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); - factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); - factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); - factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); - factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); - factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); - factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); - - if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { - return HTP_STATUS_VTCM_TOO_SMALL; - } 
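/* The sequence above cannot fail mid-allocation: vtcm_seq_alloc (see the
 * vtcm-utils.h hunk below) only bumps a cursor, so a single overflow check at
 * the end is sufficient because every size was rounded up before the bumps.
 * Minimal usage sketch with hypothetical sizes n and m:
 *
 *     uint8_t * cur = ctx->vtcm_base;
 *     float * a = (float *) vtcm_seq_alloc(&cur, hex_align_up(n * sizeof(float), 4096));
 *     float * b = (float *) vtcm_seq_alloc(&cur, hex_align_up(m * sizeof(float), 256));
 *     if ((size_t) (cur - ctx->vtcm_base) > ctx->vtcm_size) {
 *         return HTP_STATUS_VTCM_TOO_SMALL;
 *     }
 */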
- - // ======== Initialize HMX output scales ======== - // Identity scale (1.0) for O updates and normalization - hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 - - // QK scale embedded in HMX output - hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); - - // ======== Skip compute if profiling ======== - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { - return HTP_STATUS_OK; - } - - // Profiling timers - TIMER_DEFINE(total); - TIMER_DEFINE(q_load); - TIMER_DEFINE(kv_dma); - TIMER_DEFINE(k_interleave); - TIMER_DEFINE(v_interleave); - TIMER_DEFINE(qk_dot); - TIMER_DEFINE(softmax); - TIMER_DEFINE(o_update); - TIMER_DEFINE(o_norm); - TIMER_DEFINE(o_store); - - TIMER_START(total); - - // ======== DMA setup ======== - dma_queue * const dma = ctx->dma[0]; - - // Padded row sizes for DMA - const size_t size_k_row = nek0 * sizeof(__fp16); - const size_t size_v_row = nev0 * sizeof(__fp16); - const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); - const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); - - const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; - const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; - - // Q/O element size for Q load and O store - const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); - - // ======== HMX lock strategy ======== - // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. - // Fallback: main thread holds the lock (original behavior). - if (!factx.use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - } - - // ======== Reusable job descriptors for pipeline ======== - hmx_fa_qk_job_t qk_job; - hmx_fa_o_update_job_t ou_job; - hmx_fa_o_norm_job_t on_job; - - // ======== Main loop: per batch, per KV head, per Q block ======== - for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { - for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { - const uint32_t ik2 = kv_head; - const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); - const uint32_t iv2 = kv_head; - const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); - - for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { - const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); - const size_t n_rows_g = n_q_rows * G; - const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); - const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; - - // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- - TIMER_START(q_load); - if (n_rows_g < g_br) { - hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); - } - fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); - TIMER_STOP(q_load); - - // ---- Initialize per-block state ---- - hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); - hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); - hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); - - __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; - __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; - hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); - - // ---- KV block loop with DMA double-buffering ---- - size_t buf_idx = 0; - - // Prefetch first KV block - if (factx.n_kv_blocks > 0) { - const uint32_t kv_rows0 = hex_smin(Bc, nek1); - - const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; - dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], - size_k_row, kv_rows0); - - const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; - 
dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], - size_v_row, kv_rows0); - } - - // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. - // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. - #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ - do { \ - has_mask_dma_var = false; \ - if (mask && factx.mask_broadcast) { \ - const uint32_t _im3 = ib3 % mask->ne[3]; \ - const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ - (kv_start_val) * sizeof(__fp16); \ - dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ - (kv_rows_val) * sizeof(__fp16), n_q_rows); \ - has_mask_dma_var = true; \ - } \ - } while (0) - - #define MASK_DMA_POP(has_mask_dma_var) \ - do { \ - if (has_mask_dma_var) { \ - dma_queue_pop(dma); \ - } \ - } while (0) - - #define DMA_PREFETCH_KV(blk_val) \ - do { \ - if ((blk_val) < factx.n_kv_blocks) { \ - const uint32_t _ns = (blk_val) * Bc; \ - const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ - size_t _nb = 1 - buf_idx; \ - const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ - dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ - const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ - dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ - } \ - } while (0) - - const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); - const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); - - if (factx.use_pipeline) { - // ================================================================== - // Pipeline path: HVX phases ‖ HMX queue worker - // ================================================================== - struct hmx_queue * hmx_q = ctx->hmx_queue; - - for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { - const uint32_t kv_start = kv_blk * Bc; - const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); - const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); - - // Wait for current KV DMA - TIMER_START(kv_dma); - dma_queue_pop(dma); // K - dma_queue_pop(dma); // V - TIMER_STOP(kv_dma); - - // Push mask DMA for this block (single 2D DMA when broadcast) - bool has_mask_dma = false; - MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); - - // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- - if (kv_blk > 0) { - // Submit O_update for previous block (HMX worker) - ou_job.o_curr = o_tile_curr; - ou_job.o_prev = o_tile_prev; - ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; - ou_job.d_tiles = factx.vtcm_d_tiles; - ou_job.hmx_scales = factx.vtcm_hmx_scales_id; - ou_job.n_row_tiles = n_row_tiles; - ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); - ou_job.n_row_tiles_g_br = n_row_tiles_g_br; - ou_job.n_tiles_per_bc = n_tiles_per_bc; - ou_job.DV = DV; - hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); - } - - TIMER_START(k_interleave); - fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); - TIMER_STOP(k_interleave); - - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- - qk_job.q_tiles = 
factx.vtcm_q_tiles; - qk_job.k_tiles = factx.vtcm_k_tiles; - qk_job.s_tiles = factx.vtcm_s_tiles; - qk_job.n_row_tiles = n_row_tiles; - qk_job.n_col_tiles = n_col_tiles; - qk_job.n_dot_tiles = DK / 32; - qk_job.n_tiles_per_bc = n_tiles_per_bc; - qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; - TIMER_START(qk_dot); - hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); - - // DMA push next block (non-blocking, before worker_pool) - DMA_PREFETCH_KV(kv_blk + 1); - - TIMER_START(v_interleave); - fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); - TIMER_STOP(v_interleave); - - hmx_queue_pop(hmx_q); - TIMER_STOP(qk_dot); - - // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- - // Pop mask DMA before softmax (ensures VTCM buffer is ready) - MASK_DMA_POP(has_mask_dma); - - fa_softmax_args_t sargs; - memset(&sargs, 0, sizeof(sargs)); - sargs.factx = &factx; - sargs.kv_rows = kv_rows; - sargs.n_rows_g = n_rows_g; - sargs.n_col_tiles = n_col_tiles; - sargs.n_tiles_per_bc = n_tiles_per_bc; - sargs.n_row_tiles = n_row_tiles; - sargs.n_row_tiles_g_br = n_row_tiles_g_br; - sargs.Bc = Bc; - sargs.G = G; - sargs.kv_head = kv_head; - sargs.kv_start = kv_start; - sargs.q_start = q_start; - sargs.ib3 = ib3; - sargs.has_alibi = (factx.max_bias != 0.0f); - sargs.mask = mask; - sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; - sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; - sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); - - TIMER_START(softmax); - fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); - TIMER_STOP(softmax); - - buf_idx = 1 - buf_idx; - } // end KV block loop (pipeline) - - // Epilogue: O_update for last block - if (factx.n_kv_blocks > 0) { - const uint32_t last_blk = factx.n_kv_blocks - 1; - const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); - ou_job.o_curr = o_tile_curr; - ou_job.o_prev = o_tile_prev; - ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; - ou_job.d_tiles = factx.vtcm_d_tiles; - ou_job.hmx_scales = factx.vtcm_hmx_scales_id; - ou_job.n_row_tiles = n_row_tiles; - ou_job.n_col_tiles = last_cols; - ou_job.n_row_tiles_g_br = n_row_tiles_g_br; - ou_job.n_tiles_per_bc = n_tiles_per_bc; - ou_job.DV = DV; - - TIMER_START(o_update); - hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); - hmx_queue_pop(hmx_q); - TIMER_STOP(o_update); - - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - - } else { - // ================================================================== - // Fallback path: sequential with multi-thread HVX phases - // Main thread holds HMX lock, runs HMX inline. 
- // ================================================================== - - for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { - const uint32_t kv_start = kv_blk * Bc; - const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); - const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); - - TIMER_START(kv_dma); - dma_queue_pop(dma); // K - dma_queue_pop(dma); // V - TIMER_STOP(kv_dma); - - bool has_mask_dma = false; - MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); - DMA_PREFETCH_KV(kv_blk + 1); - - // K interleave (multi-thread HVX) - TIMER_START(k_interleave); - fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); - TIMER_STOP(k_interleave); - - // QK dot (inline HMX on main thread) - TIMER_START(qk_dot); - { - const size_t n_dot_tiles = (size_t) (DK / 32); - const __fp16 * restrict q_base = factx.vtcm_q_tiles; - const __fp16 * restrict k_base = factx.vtcm_k_tiles; - __fp16 * restrict s_base = factx.vtcm_s_tiles; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < n_col_tiles; ++c) { - const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; - const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; - __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; - for (size_t k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(out_tile, 0); - } - } - } - TIMER_STOP(qk_dot); - - // Pop mask DMA - MASK_DMA_POP(has_mask_dma); - - // Softmax + build_D (multi-thread HVX + serial m/l update) - fa_softmax_args_t sargs; - memset(&sargs, 0, sizeof(sargs)); - sargs.factx = &factx; - sargs.kv_rows = kv_rows; - sargs.n_rows_g = n_rows_g; - sargs.n_col_tiles = n_col_tiles; - sargs.n_tiles_per_bc = n_tiles_per_bc; - sargs.n_row_tiles = n_row_tiles; - sargs.n_row_tiles_g_br = n_row_tiles_g_br; - sargs.Bc = Bc; - sargs.G = G; - sargs.kv_head = kv_head; - sargs.kv_start = kv_start; - sargs.q_start = q_start; - sargs.ib3 = ib3; - sargs.has_alibi = (factx.max_bias != 0.0f); - sargs.mask = mask; - sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; - sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; - sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); - - TIMER_START(softmax); - fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); - TIMER_STOP(softmax); - - // V interleave (multi-thread HVX) - TIMER_START(v_interleave); - // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout - // stride to match o_update's v_tile access. Using per-block n_col_tiles - // misplaces DV_tile 1..3 in the last partial KV block. 
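/* Concrete illustration of the stride mismatch (hypothetical sizes): with
 * Bc = 128, n_tiles_per_bc = 128 / 32 = 4; a final partial block with
 * kv_rows = 70 gives n_col_tiles = ceil(70 / 32) = 3. Interleaving V with
 * group stride 3 would place DV tile c at tile offset 3 * c, while o_update
 * reads it at the block-invariant offset 4 * c (v_base + c * n_tiles_per_bc
 * tiles), so every DV tile past the first would be fetched from the wrong
 * address. Fixing the layout stride at n_tiles_per_bc keeps producer and
 * consumer consistent for full and partial blocks alike. */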
- fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); - TIMER_STOP(v_interleave); - - // O update (inline HMX on main thread) - TIMER_START(o_update); - { - const size_t DV_tiles = (size_t) (DV / 32); - const __fp16 * restrict d_base = factx.vtcm_d_tiles; - const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; - const __fp16 * restrict op_base = o_tile_prev; - __fp16 * restrict oc_base = o_tile_curr; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(DV_tiles > 0); - - Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < DV_tiles; ++c) { - const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); - - const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; - const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; - for (size_t k = 0; k < n_col_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); - p_tile_in += HMX_FP16_TILE_N_ELMS; - v_tile_in += HMX_FP16_TILE_N_ELMS; - } - - __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - Q6_mxmem_AR_after_hf(o_tile_out, 0); - } - } - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - TIMER_STOP(o_update); - - buf_idx = 1 - buf_idx; - } // end KV block loop (fallback) - } - - // ---- Final normalization: O = diag(1/l) @ O ---- - TIMER_START(o_norm); - { - fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); - - // HMX: O_final = diag(1/l) @ O_prev - if (factx.use_pipeline) { - on_job.o_curr = o_tile_curr; - on_job.o_prev = o_tile_prev; - on_job.d_tiles = factx.vtcm_d_tiles; - on_job.hmx_scales = factx.vtcm_hmx_scales_id; - on_job.n_row_tiles = n_row_tiles; - on_job.n_row_tiles_g_br = n_row_tiles_g_br; - on_job.DV = DV; - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); - hmx_queue_pop(ctx->hmx_queue); - } else { - const size_t DV_tiles = (size_t) (DV / 32); - const __fp16 * restrict d_base = factx.vtcm_d_tiles; - const __fp16 * restrict op_base = o_tile_prev; - __fp16 * restrict oc_base = o_tile_curr; - __builtin_assume(n_row_tiles > 0); - __builtin_assume(DV_tiles > 0); - - Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); - for (size_t r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < DV_tiles; ++c) { - const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; - const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; - __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; - - Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); - Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); - Q6_mxmem_AR_after_hf(o_out, 0); - } - } - } - } - TIMER_STOP(o_norm); - - // ---- Store O block ---- - TIMER_START(o_store); - fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); - TIMER_STOP(o_store); - -#undef MASK_DMA_PUSH -#undef MASK_DMA_POP -#undef DMA_PREFETCH_KV - - } // end Q block loop - } // end KV head loop - } // end batch loop - - if (factx.use_pipeline) { - hmx_queue_suspend(ctx->hmx_queue); - } else { - 
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), - TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); - FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), - TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); -#endif - - return HTP_STATUS_OK; -} diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h deleted file mode 100644 index b129fb74e..000000000 --- a/ggml/src/ggml-hexagon/htp/vtcm-utils.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef VTCM_UTILS_H -#define VTCM_UTILS_H - -#include "hex-utils.h" - -#include -#include -#include - -static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { - uint8_t *p = *vtcm_ptr; - *vtcm_ptr += size; - return p; -} - -#endif // VTCM_UTILS_H diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl deleted file mode 100644 index e404f392b..000000000 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +++ /dev/null @@ -1,302 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable - -#define TILESIZE_K 16 -#define TILESIZE_M 64 -#define TILESIZE_N 32 - - -static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { - ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; - fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; - fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; - fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; - fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; - - bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; - bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; - bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; - bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; - - fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; - fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; - fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; - fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; - - sign_a.lo = (fp4x8.s0 << 12) & 0x8000; - sign_a.hi = (fp4x8.s0 << 8) & 0x8000; - sign_b.lo = (fp4x8.s0 << 4) & 0x8000; - sign_b.hi = fp4x8.s0 & 0x8000; - - fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; - fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; - - ushort2 fp16_packed_a_1, fp16_packed_b_1; - fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; - fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; - fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; - fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; - - bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; - bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; - bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; - bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; - - fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; - fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; - fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? 
fp16_packed_b_1.lo : 0x0; - fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; - - sign_a.lo = (fp4x8.s1 << 12) & 0x8000; - sign_a.hi = (fp4x8.s1 << 8) & 0x8000; - sign_b.lo = (fp4x8.s1 << 4) & 0x8000; - sign_b.hi = fp4x8.s1 & 0x8000; - - fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; - fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; - - return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); -} - - -#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ - acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ - acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ - acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ - acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ - acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ - acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ - acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ - acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ - acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ - acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ - acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ - acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ - acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ - acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ - acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ - acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ - acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ - acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ - acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ - acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ - acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ - acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ - acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ - acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ - acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ - acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ - acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ - acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ - acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ - acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ - acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ - acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ - c_reg.lo += convert_float8(acc.lo); \ - c_reg.hi += convert_float8(acc.hi); \ - acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ - acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ - acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ - acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ - acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ - acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ - acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ - acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ - acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ - acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ - acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ - acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ - acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ - acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ - acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ - acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ - acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ - acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ - acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ - acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ - acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ - acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ - acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ - 
acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ - acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ - acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ - acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ - acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ - acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ - acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ - acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ - acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ - c_reg.lo += convert_float8(acc.lo); \ - c_reg.hi += convert_float8(acc.hi); \ - - -static inline half e8m0_to_fp16(uchar x) { - ushort bits; - bits = (ushort)(x) - (ushort)(112); - bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10); - return as_half(bits); -} - -static inline float e8m0_to_fp32(uchar x) { - int bits; - bits = (x == 0) ? 0x00400000 : ((uint) x << 23); - return as_float(bits); -} - - -__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair -kernel void kernel_gemm_moe_mxfp4_f32_ns( - __read_only image1d_buffer_t src0_q, - __global uchar * src0_d, - __read_only image1d_buffer_t src1, - __global uint * src2, - __global ushort * src2_emap, - __write_only image1d_buffer_t dst, - __global int * total_tiles, - uint ne00, - uint ne01 -) { - uint block_id_m = get_global_id(1); // m_tile - uint block_id_n = get_global_id(2); // n_tile - - // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { - return; - } - - __private half16 reg_a; - __private float32 reg_c = (float32)(0); - __local half4 shared_b[128]; - - const ushort expert_id = src2_emap[block_id_n]; - - const uint row = block_id_m * TILESIZE_M; - const uint col = block_id_n * TILESIZE_N; - - uint sub_block_id_m = get_local_id(0); - uint2 b_global_offset; - b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; - b_global_offset.y = b_global_offset.x + (16 * ne00); - uint2 b_local_offset; - b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); - b_local_offset.y = b_local_offset.x + 16; - - // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks - for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { - // First sub-block - uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); - uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); - uint b_sub_offset = col * ne00 + step; - - // Load scale for current mxfp4 block - uint s_offset = s_sub_offset + get_global_id(0); - float s = e8m0_to_fp32(src0_d[s_offset]); - - // Load 16 fp4 (64-bits) in transposed layout - uint2 mxfp4x16; - mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; - mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; - - // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements - float8 bx8_f32; - bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); - bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); - // Convert to half and store to LM to share within the subgroup - half8 bx8_f16 = convert_half8(bx8_f32); - shared_b[b_local_offset.x] = bx8_f16.lo; - shared_b[b_local_offset.y] = bx8_f16.hi; - - // Dequantization - reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; - reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; - - sub_group_barrier(CLK_LOCAL_MEM_FENCE); - - // 32 16x16 fp16 dot product with 8 elements reduction for better 
precision - half16 acc; - dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); - dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); - - // Repeat for second sub-block - uint half_step = step + TILESIZE_K; - q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); - b_sub_offset = col * ne00 + half_step; - - // Load next 16 fp4 (64-bits) in transposed layout - mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; - mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; - - // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements - bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); - bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); - // Convert to half and store to LM to share within the subgroup - bx8_f16 = convert_half8(bx8_f32); - shared_b[b_local_offset.x] = bx8_f16.lo; - shared_b[b_local_offset.y] = bx8_f16.hi; - - // Dequantization - reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; - reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; - - sub_group_barrier(CLK_LOCAL_MEM_FENCE); - - // 32 16x16 fp16 dot product with 3-levels reduction for better precision - dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); - dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); - } - - // Load poster router and share in LM - __local uint out_idx[TILESIZE_N]; - - if (get_local_id(0) < TILESIZE_N) { - uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; - if (idx == 0xFFFFFFFF) { - idx = src2[block_id_n * TILESIZE_N + 0]; - } - out_idx[get_local_id(0)] = idx * ne01; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - // Scatter results back to original position in output grid - uint m_offset = row + get_local_id(0); - - write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); - write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); - write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); - write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); - write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); - write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); - write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); - write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); - write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); - write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); - write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); - write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); - write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); - write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); - write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); - write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); - write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); - write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); - write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); - write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); - write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); - write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); - write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); - write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); - write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); - write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); - write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); - write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); - write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); - write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); - write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); - - // Store zero padding 
parts to the index of first output in tile, override correct result in the end - barrier(CLK_GLOBAL_MEM_FENCE); - write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); -} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl deleted file mode 100644 index 02290c17e..000000000 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +++ /dev/null @@ -1,252 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable -#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable -#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable - -#define TILESIZE_K 16 -#define TILESIZE_M 64 -#define TILESIZE_N 32 - - -#define dequantize_q4_0(q4, a_f16, scale) \ - a_f16.s0 = (half)((q4.s0 & 0x000F) - 8) * scale; \ - a_f16.s1 = (half)(((q4.s0 & 0x00F0) >> 4) - 8) * scale; \ - a_f16.s2 = (half)(((q4.s0 & 0x0F00) >> 8) - 8) * scale; \ - a_f16.s3 = (half)(((q4.s0 & 0xF000) >> 12) - 8) * scale; \ - a_f16.s4 = (half)((q4.s1 & 0x000F) - 8) * scale; \ - a_f16.s5 = (half)(((q4.s1 & 0x00F0) >> 4) - 8) * scale; \ - a_f16.s6 = (half)(((q4.s1 & 0x0F00) >> 8) - 8) * scale; \ - a_f16.s7 = (half)(((q4.s1 & 0xF000) >> 12) - 8) * scale; \ - a_f16.s8 = (half)((q4.s2 & 0x000F) - 8) * scale; \ - a_f16.s9 = (half)(((q4.s2 & 0x00F0) >> 4) - 8) * scale; \ - a_f16.sa = (half)(((q4.s2 & 0x0F00) >> 8) - 8) * scale; \ - a_f16.sb = (half)(((q4.s2 & 0xF000) >> 12) - 8) * scale; \ - a_f16.sc = (half)((q4.s3 & 0x000F) - 8) * scale; \ - a_f16.sd = (half)(((q4.s3 & 0x00F0) >> 4) - 8) * scale; \ - a_f16.se = (half)(((q4.s3 & 0x0F00) >> 8) - 8) * scale; \ - a_f16.sf = (half)(((q4.s3 & 0xF000) >> 12) - 8) * scale; \ - - -#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ - acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ - acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ - acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ - acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ - acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ - acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ - acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ - acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ - acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ - acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ - acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ - acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ - acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ - acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ - acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ - acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ - acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ - acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ - acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ - acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ - acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ - acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ - acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ - acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ - acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ - acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ - acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ - acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ - acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ - acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ - acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ - acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ - c_reg.lo 
+= convert_float8(acc.lo); \ - c_reg.hi += convert_float8(acc.hi); \ - acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ - acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ - acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ - acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ - acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ - acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ - acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ - acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ - acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ - acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ - acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ - acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ - acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ - acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ - acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ - acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ - acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ - acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ - acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ - acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ - acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ - acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ - acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ - acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ - acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ - acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ - acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ - acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ - acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ - acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ - acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ - acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ - c_reg.lo += convert_float8(acc.lo); \ - c_reg.hi += convert_float8(acc.hi); \ - - -__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair -kernel void kernel_gemm_moe_q4_0_f32_ns( - __read_only image1d_buffer_t src0_q, - __global half * src0_d, - __read_only image1d_buffer_t src1, - __global uint * src2, - __global ushort * src2_emap, - __write_only image1d_buffer_t dst, - __global int * total_tiles, - uint ne00, - uint ne01 -) { - uint block_id_m = get_global_id(1); // m_tile - uint block_id_n = get_global_id(2); // n_tile - - // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { - return; - } - - __private half16 reg_a; - __private float32 reg_c = (float32)(0); - __local half4 shared_b[128]; - - const ushort expert_id = src2_emap[block_id_n]; - - const uint row = block_id_m * TILESIZE_M; - const uint col = block_id_n * TILESIZE_N; - - uint sub_block_id_m = get_local_id(0); - uint2 b_global_offset; - b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; - b_global_offset.y = b_global_offset.x + (16 * ne00); - uint2 b_local_offset; - b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); - b_local_offset.y = b_local_offset.x + 16; - - // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks - for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { - // First sub-block - uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); - uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); - uint b_sub_offset = col * ne00 + step; - - // Load scale for current Q4_0 block - uint s_offset = s_sub_offset + get_global_id(0); - half s = 
src0_d[s_offset]; - - // Load 16 q (64-bits) in transposed layout - uint2 q4x16; - q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; - q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; - - // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements - float8 bx8_f32; - bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); - bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); - // Convert to half and store to LM to share within the subgroup - half8 bx8_f16 = convert_half8(bx8_f32); - shared_b[b_local_offset.x] = bx8_f16.lo; - shared_b[b_local_offset.y] = bx8_f16.hi; - - // Dequantization - dequantize_q4_0(as_ushort4(q4x16), reg_a, s); - - sub_group_barrier(CLK_LOCAL_MEM_FENCE); - - // 32 16x16 fp16 dot product with 8 elements reduction for better precision - half16 acc; - dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); - dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); - - // Repeat for second sub-block - uint half_step = step + TILESIZE_K; - q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); - b_sub_offset = col * ne00 + half_step; - - // Load next 16 q (64-bits) in transposed layout - q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; - q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; - - // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements - bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); - bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); - // Convert to half and store to LM to share within the subgroup - bx8_f16 = convert_half8(bx8_f32); - shared_b[b_local_offset.x] = bx8_f16.lo; - shared_b[b_local_offset.y] = bx8_f16.hi; - - // Dequantization - dequantize_q4_0(as_ushort4(q4x16), reg_a, s); - - sub_group_barrier(CLK_LOCAL_MEM_FENCE); - - // 32 16x16 fp16 dot product with 3-levels reduction for better precision - dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); - dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); - } - - // Load poster router and share in LM - __local uint out_idx[TILESIZE_N]; - - if (get_local_id(0) < TILESIZE_N) { - uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; - if (idx == 0xFFFFFFFF) { - idx = src2[block_id_n * TILESIZE_N + 0]; - } - out_idx[get_local_id(0)] = idx * ne01; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - // Scatter results back to original position in output grid - uint m_offset = row + get_local_id(0); - - write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); - write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); - write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); - write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); - write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); - write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); - write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); - write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); - write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); - write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); - write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); - write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); - write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); - write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); - write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); - write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); - write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); - write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); - 
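// Note on the out_idx indirection used by these stores: src2 holds one
// post-router row id per tile slot, with 0xFFFFFFFF marking the padding slots
// of a partially filled tile. Padded slots were redirected to slot 0's row, so
// their junk writes all collide on a single row, and after the global memory
// fence below reg_c.s0 is stored last to overwrite that row with the correct
// value. This keeps the store sequence branch-free.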
write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); - write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); - write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); - write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); - write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); - write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); - write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); - write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); - write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); - write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); - write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); - write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); - write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); - - // Store zero padding parts to the index of first output in tile, override correct result in the end - barrier(CLK_GLOBAL_MEM_FENCE); - write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); -} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl deleted file mode 100644 index e4b44c1a5..000000000 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +++ /dev/null @@ -1,161 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable - -#define QK_MXFP4 32 -#define N_SIMDGROUP 4 -#define SIMDGROUP_WIDTH 64 - -static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { - ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; - fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; - fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; - fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; - fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; - - bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; - bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; - bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; - bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; - - fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; - fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; - fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; - fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; - - sign_a.lo = (fp4x8.s0 << 12) & 0x8000; - sign_a.hi = (fp4x8.s0 << 8) & 0x8000; - sign_b.lo = (fp4x8.s0 << 4) & 0x8000; - sign_b.hi = fp4x8.s0 & 0x8000; - - fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; - fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; - - ushort2 fp16_packed_a_1, fp16_packed_b_1; - fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; - fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; - fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; - fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; - - bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; - bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; - bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; - bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; - - fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; - fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; - fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; - fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? 
fp16_packed_b_1.hi : 0x0; - - sign_a.lo = (fp4x8.s1 << 12) & 0x8000; - sign_a.hi = (fp4x8.s1 << 8) & 0x8000; - sign_b.lo = (fp4x8.s1 << 4) & 0x8000; - sign_b.hi = fp4x8.s1 & 0x8000; - - fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; - fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; - - return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); -} - -static inline float e8m0_to_fp32(uchar x) { - int bits; - bits = (x == 0) ? 0x00400000 : ((uint) x << 23); - return as_float(bits); -} - - -__attribute__((qcom_reqd_sub_group_size("half"))) -__kernel void kernel_gemv_moe_mxfp4_f32_ns( - __global uint * src0_q, - __global uchar * src0_e, - __read_only image1d_buffer_t src1, - __global uint * src2, - __global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne11 -) { - uint i01 = get_global_id(0); - uint i20 = get_global_id(2); - uint sgid = get_local_id(1); - uint slid = get_sub_group_local_id(); - - uint i11 = i20 % ne11; - - uint expert_id = src2[i20]; - uint expert_offset = expert_id * ne00 * ne01 / 32; - - __private float sum = 0.0f; // each thread calculate partial sum of one output - - // loop along ne00 in block granularity, skip 4 blocks every iter - for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { - - // load one block of q - uint4 regQ; - uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; - - regQ.s0 = src0_q[block_offset]; - regQ.s1 = src0_q[block_offset + ne01]; - regQ.s2 = src0_q[block_offset + ne01 * 2]; - regQ.s3 = src0_q[block_offset + ne01 * 3]; - - uint offset = i11 * ne00 / 4 + ib00 * 8; - - half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); - - float4 shared_y4; - shared_y4 = read_imagef(src1, (offset + 0)); - float4 acc = shared_y4 * convert_float4(fp16x8.lo); - - shared_y4 = read_imagef(src1, (offset + 1)); - acc += shared_y4 * convert_float4(fp16x8.hi); - - fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); - - shared_y4 = read_imagef(src1, (offset + 2)); - acc += shared_y4 * convert_float4(fp16x8.lo); - - shared_y4 = read_imagef(src1, (offset + 3)); - acc += shared_y4 * convert_float4(fp16x8.hi); - - - fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); - - shared_y4 = read_imagef(src1, (offset + 4)); - acc += shared_y4 * convert_float4(fp16x8.lo); - - shared_y4 = read_imagef(src1, (offset + 5)); - acc += shared_y4 * convert_float4(fp16x8.hi); - - - fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); - - shared_y4 = read_imagef(src1, (offset + 6)); - acc += shared_y4 * convert_float4(fp16x8.lo); - - shared_y4 = read_imagef(src1, (offset + 7)); - acc += shared_y4 * convert_float4(fp16x8.hi); - - uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; - sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); - } - - // reduction in local memory, assumes #subgroups=4 - __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; - if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; - if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; - if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; - if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; - if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; - - // 1 outputs per thread in subgroup 0 - if (sgid == 0) { - dst = dst + (offsetd >> 2); - dst[i01 + i20 * ne01] = sum; - } - -} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl 
b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl deleted file mode 100644 index 6f4d3f532..000000000 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +++ /dev/null @@ -1,116 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable - -#define QK_Q4_0 32 -#define N_SIMDGROUP 4 -#define SIMDGROUP_WIDTH 64 - -static inline float8 q4_0_to_fp32_packed8(ushort2 q4x8) { - float8 fp32x8; - fp32x8.s0 = (float)((q4x8.s0 & 0x000F) - 8); - fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) - 8); - fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) - 8); - fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) - 8); - fp32x8.s4 = (float)((q4x8.s1 & 0x000F) - 8); - fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) - 8); - fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) - 8); - fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) - 8); - return fp32x8; -} - - -__attribute__((qcom_reqd_sub_group_size("half"))) -__kernel void kernel_gemv_moe_q4_0_f32_ns( - __global uint * src0_q, - __global half * src0_d, - __read_only image1d_buffer_t src1, - __global uint * src2, - __global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne11 -) { - uint i01 = get_global_id(0); - uint i20 = get_global_id(2); - uint sgid = get_local_id(1); - uint slid = get_sub_group_local_id(); - - uint i11 = i20 % ne11; - - uint expert_id = src2[i20]; - uint expert_offset = expert_id * ne00 * ne01 / 32; - - __private float sum = 0.0f; // each thread calculate partial sum of one output - - // loop along ne00 in block granularity, skip 4 blocks every iter - for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_0); ib00 += N_SIMDGROUP) { - - // load one block of q - uint4 regQ; - uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; - - regQ.s0 = src0_q[block_offset]; - regQ.s1 = src0_q[block_offset + ne01]; - regQ.s2 = src0_q[block_offset + ne01 * 2]; - regQ.s3 = src0_q[block_offset + ne01 * 3]; - - uint offset = i11 * ne00 / 4 + ib00 * 8; - - float8 fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s0)); - - float4 shared_y4; - shared_y4 = read_imagef(src1, (offset + 0)); - float4 acc = shared_y4 * fp32x8.lo; - - shared_y4 = read_imagef(src1, (offset + 1)); - acc += shared_y4 * fp32x8.hi; - - fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s1)); - - shared_y4 = read_imagef(src1, (offset + 2)); - acc += shared_y4 * fp32x8.lo; - - shared_y4 = read_imagef(src1, (offset + 3)); - acc += shared_y4 * fp32x8.hi; - - - fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s2)); - - shared_y4 = read_imagef(src1, (offset + 4)); - acc += shared_y4 * fp32x8.lo; - - shared_y4 = read_imagef(src1, (offset + 5)); - acc += shared_y4 * fp32x8.hi; - - - fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s3)); - - shared_y4 = read_imagef(src1, (offset + 6)); - acc += shared_y4 * fp32x8.lo; - - shared_y4 = read_imagef(src1, (offset + 7)); - acc += shared_y4 * fp32x8.hi; - - half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; - sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); - } - - // reduction in local memory, assumes #subgroups=4 - __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; - if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; - if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; - if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; - barrier(CLK_LOCAL_MEM_FENCE); - if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; - if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; - if (sgid == 0) sum += 
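// Reference for the Q4_0 path above: each 4-bit code q in 0..15 encodes
// x = (q - 8) * d, with a single fp16 delta d per 32-element block. Note the
// kernel keeps the decoded values unscaled and applies d once per block after
// the dot products (sum += (float)(regS) * ...), rather than per element.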
reduceLM[SIMDGROUP_WIDTH * 2 + slid]; - - // 1 outputs per thread in subgroup 0 - if (sgid == 0) { - dst = dst + (offsetd >> 2); - dst[i01 + i20 * ne01] = sum; - } - -} diff --git a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl deleted file mode 100644 index e6295c816..000000000 --- a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +++ /dev/null @@ -1,30 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -#define QK4_0 32 - -kernel void kernel_moe_reorder_b( - global float4 * src, - global uint * router, - global float4 * dst, - global int * total_tiles, - uint K, - ushort map_ratio, - uint tile_size -) { - uint k_4 = get_global_id(0); - uint post_router_idx = get_global_id(1); - - if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) { - return; - } - - uint router_idx = router[post_router_idx]; - - float4 out = (float4)(0); - if (router_idx != 0xFFFFFFFF) { - ushort activation_idx = router_idx / map_ratio; - out = src[activation_idx * K / 4 + k_4]; - } - - dst[post_router_idx * K / 4 + k_4] = out; -} diff --git a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl deleted file mode 100644 index d9703429b..000000000 --- a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +++ /dev/null @@ -1,82 +0,0 @@ -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -__kernel void kernel_moe_histogram( - __global const int * input, - __global int * hist, - uint N, - uint topK, - uint n_experts -) { - uint n = get_global_id(0); - uint k = get_global_id(1); - - if (n >= N || k >= topK) { - return; - } - - int expert_id = input[n * n_experts + k]; - atomic_inc(&hist[expert_id]); -} - -__kernel void kernel_moe_scan( - __global int * hist, - __global int * tile_offset, - __global int * total_tiles, - __global int * slot_counter, - int tile_size, - uint n_experts -) { - int offset = 0; - for (int v = 0; v < n_experts; v++) { - int count = hist[v]; - int tiles = (count + tile_size - 1) / tile_size; - tile_offset[v] = offset; - offset += tiles; - hist[v] = 0; - slot_counter[v] = 0; - } - - *total_tiles = offset; -} - -__kernel void kernel_moe_scatter( - __global const int * input, - __global int * post_router, - __global ushort * emap, - __global const int * tile_offset, - __global int * slot_counter, - int N, - int topK, - uint n_experts -) { - uint n = get_global_id(0); - uint k = get_global_id(1); - - if (n >= N || k >= topK) { - return; - } - - int val = input[n * n_experts + k]; - - int local_slot = atomic_inc(&slot_counter[val]); - - int tile_idx = tile_offset[val] + (local_slot / 32); - int lane = local_slot % 32; - int out_pos = tile_idx * 32 + lane; - - post_router[out_pos] = n * topK + k; - emap[tile_idx] = val; -} - -__kernel void kernel_moe_fill( - __global int * post_router, - __global int * total_tiles, - int tile_size -) { - int tile_id = get_global_id(0); - int vec_id_in_tile = get_global_id(1); - - if (tile_id < total_tiles[0]) { - post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF; - } -} diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp new file mode 100644 index 000000000..1cb8f563d --- /dev/null +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -0,0 +1,1974 @@ +#include "ggml-rpc.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-cpp.h" +#include "transport.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char * 
RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + + +namespace fs = std::filesystem; + +// macro for nicer error messages on server crash +#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") + +// all RPC structures must be packed +#pragma pack(push, 1) +// ggml_tensor is serialized into rpc_tensor +struct rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); + +// RPC commands +enum rpc_cmd { + RPC_CMD_ALLOC_BUFFER = 0, + RPC_CMD_GET_ALIGNMENT, + RPC_CMD_GET_MAX_SIZE, + RPC_CMD_BUFFER_GET_BASE, + RPC_CMD_FREE_BUFFER, + RPC_CMD_BUFFER_CLEAR, + RPC_CMD_SET_TENSOR, + RPC_CMD_SET_TENSOR_HASH, + RPC_CMD_GET_TENSOR, + RPC_CMD_COPY_TENSOR, + RPC_CMD_GRAPH_COMPUTE, + RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, + RPC_CMD_DEVICE_COUNT, + RPC_CMD_GRAPH_RECOMPUTE, + RPC_CMD_COUNT, +}; + +static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); + +// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold +const size_t HASH_THRESHOLD = 10 * 1024 * 1024; + +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + +struct rpc_msg_device_count_rsp { + uint32_t device_count; +}; + +struct rpc_msg_get_alloc_size_req { + uint32_t device; + rpc_tensor tensor; + rpc_tensor srcs[GGML_MAX_SRC]; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + +struct rpc_msg_alloc_buffer_req { + uint32_t device; + uint64_t size; +}; + +struct rpc_msg_alloc_buffer_rsp { + uint64_t remote_ptr; + uint64_t remote_size; +}; + +struct rpc_msg_get_alignment_req { + uint32_t device; +}; + +struct rpc_msg_get_alignment_rsp { + uint64_t alignment; +}; + +struct rpc_msg_get_max_size_req { + uint32_t device; +}; + +struct rpc_msg_get_max_size_rsp { + uint64_t max_size; +}; + +struct rpc_msg_buffer_get_base_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_get_base_rsp { + uint64_t base_ptr; +}; + +struct rpc_msg_free_buffer_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_clear_req { + uint64_t remote_ptr; + uint8_t value; +}; + +struct rpc_msg_set_tensor_hash_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t hash; +}; + +struct rpc_msg_set_tensor_hash_rsp { + uint8_t result; +}; + +struct rpc_msg_get_tensor_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t size; +}; + +struct rpc_msg_copy_tensor_req { + rpc_tensor src; + rpc_tensor dst; +}; + +struct rpc_msg_copy_tensor_rsp { + uint8_t result; +}; + +struct rpc_msg_get_device_memory_req { + uint32_t device; +}; + +struct rpc_msg_get_device_memory_rsp { + uint64_t free_mem; + uint64_t total_mem; +}; + +struct rpc_msg_graph_recompute_req { + uint32_t device; +}; + +#pragma pack(pop) + +// RPC data structures + +static ggml_guid_t ggml_backend_rpc_guid() { + static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 
0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; + return &guid; +} + +struct ggml_backend_rpc_buffer_type_context { + std::string endpoint; + uint32_t device; + std::string name; + size_t alignment; + size_t max_size; +}; + +struct ggml_backend_rpc_context { + std::string endpoint; + uint32_t device; + std::string name; + uint64_t last_graph_uid; +}; + +struct ggml_backend_rpc_buffer_context { + std::shared_ptr sock; + void * base_ptr; + uint64_t remote_ptr; +}; + +// RPC helper functions + +// Computes FNV-1a hash of the data +static uint64_t fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return hash; +} + +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { + return false; + } + return sock->send_data(msg, msg_size); +} + +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { + uint64_t size; + if (!sock->recv_data(&size, sizeof(size))) { + return false; + } + if (size != msg_size) { + return false; + } + return sock->recv_data(msg, msg_size); +} + +static bool recv_msg(socket_ptr sock, std::vector & input) { + uint64_t size; + if (!sock->recv_data(&size, sizeof(size))) { + return false; + } + try { + input.resize(size); + } catch (const std::bad_alloc & e) { + GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); + return false; + } + return sock->recv_data(input.data(), size); +} + +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); + if (pos == std::string::npos) { + return false; + } + host = endpoint.substr(0, pos); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) { + return false; + } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// No response +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { + uint8_t cmd_byte = cmd; + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { + return false; + } + if (!sock->send_data(&input_size, sizeof(input_size))) { + return false; + } + if (!sock->send_data(input, input_size)) { + return false; + } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { + if (!send_rpc_cmd(sock, cmd, input, input_size)) { + return false; + } + uint64_t out_size; + if (!sock->recv_data(&out_size, sizeof(out_size))) { + return false; + } + if (out_size != output_size) { + return false; + } + if (!sock->recv_data(output, output_size)) { + return false; + } + return true; +} + +// RPC client-side implementation + +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. 
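For a concrete sense of the framing cost, the handshake below works out as follows (byte counts derived from the packed structs above):
// | cmd: 1 | request_size: 8 | rpc_msg_hello_req |  -> 9 + RPC_CONN_CAPS_SIZE bytes sent
// | response_size: 8 | rpc_msg_hello_rsp |          -> 12 + RPC_CONN_CAPS_SIZE bytes received
// (the response carries major/minor/patch plus one padding byte ahead of conn_caps)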
+static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); + return false; + } + + sock->update_caps(response.conn_caps); + return true; +} + +static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets; + + auto it = sockets.find(endpoint); + if (it != sockets.end()) { + if (auto sock = it->second.lock()) { + return sock; + } + } + std::string host; + int port; + if (!parse_endpoint(endpoint, host, port)) { + GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); + return nullptr; + } + + if (!rpc_transport_init()) { + return nullptr; + } + auto sock = socket_t::connect(host.c_str(), port); + if (sock == nullptr) { + return nullptr; + } + if (!negotiate_hello(sock)) { + return nullptr; + } + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); + sockets[endpoint] = sock; + return sock; +} + +static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_free_buffer_req request = {ctx->remote_ptr}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + delete ctx; +} + +static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; + } + rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; + rpc_msg_buffer_get_base_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr); + return ctx->base_ptr; +} + +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + +static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { + rpc_tensor result; + if (!tensor) { + memset(&result, 0, sizeof(result)); + return result; + } + + result.id = reinterpret_cast<uint64_t>(tensor); + result.type = tensor->type; + if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) { + ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + result.buffer = ctx != nullptr ?
ctx->remote_ptr : 0; + result.data = reinterpret_cast(tensor->data); + } else { + result.buffer = 0; + result.data = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + + // Avoid sending uninitialized data over the wire + memset(result.name, 0, sizeof(result.name)); + memset(result.padding, 0, sizeof(result.padding)); + + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. + // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_tensor rpc_tensor = serialize_tensor(tensor); + if (size > HASH_THRESHOLD) { + rpc_msg_set_tensor_hash_req request; + request.tensor = rpc_tensor; + request.offset = offset; + request.hash = fnv_hash((const uint8_t*)data, size); + rpc_msg_set_tensor_hash_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.result) { + // the server has the same data, no need to send it + return; + } + } + // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) + size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; + std::vector input(input_size, 0); + memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); + memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); + memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()); + RPC_STATUS_ASSERT(status); +} + +static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_get_tensor_req request; + request.tensor = serialize_tensor(tensor); + request.offset = offset; + request.size = size; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); + RPC_STATUS_ASSERT(status); +} + +static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, 
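The large-upload branch of set_tensor above condenses to the following decision helper (the function name is ours, for illustration; the types and commands are the ones defined in this file): for payloads past HASH_THRESHOLD the client first offers an FNV-1a digest and ships the raw bytes only when the server does not already hold matching data.

static bool rpc_try_dedup_upload(socket_ptr sock, const rpc_tensor & t, uint64_t offset,
                                 const uint8_t * data, size_t size) {
    if (size <= HASH_THRESHOLD) {
        return false;                       // small payload: always send raw bytes
    }
    rpc_msg_set_tensor_hash_req request = { t, offset, fnv_hash(data, size) };
    rpc_msg_set_tensor_hash_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_SET_TENSOR_HASH,
                               &request, sizeof(request), &response, sizeof(response));
    return status && response.result != 0;  // true: server already has these bytes
}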
ggml_tensor * dst) { + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; + } + return false; +} + +static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); + RPC_STATUS_ASSERT(status); +} + +static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { + /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, + /* .get_base = */ ggml_backend_rpc_buffer_get_base, + /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, + /* .clear = */ ggml_backend_rpc_buffer_clear, + /* .reset = */ NULL, +}; + +static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + rpc_msg_alloc_buffer_req request = {buft_ctx->device, size}; + rpc_msg_alloc_buffer_rsp response; + auto sock = get_socket(buft_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.remote_ptr != 0) { + ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, + ggml_backend_rpc_buffer_interface, + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, + response.remote_size); + return buffer; + } else { + return nullptr; + } +} + +static size_t get_alignment(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_alignment_req request = {device}; + rpc_msg_get_alignment_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.alignment; +} + +static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->alignment; +} + +static size_t get_max_size(const 
std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_max_size_req request = {device}; + rpc_msg_get_max_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.max_size; +} + +static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->max_size; +} + +static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + // should we query the remote server for the actual size + bool rpc_get = false; + + // See comments in init_tensor. + rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr); + + // ops that require additional memory for fleeting data on certain backends + // ref: https://github.com/ggml-org/llama.cpp/pull/15966 + rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT; + rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID; + + if (rpc_get) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request = { + /*.device =*/ buft_ctx->device, + /*.tensor =*/ serialize_tensor(tensor), + /*.srcs =*/ {}, + }; + + // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well + for (int i = 0; i < GGML_MAX_SRC; i++) { + request.srcs[i] = serialize_tensor(tensor->src[i]); + } + + // TODO: cache the alloc responses to avoid extra RPC calls? + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + + return response.alloc_size; + } + + return ggml_nbytes(tensor); +} + +static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { + /* .get_name = */ ggml_backend_rpc_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_rpc_get_max_size, + /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +static const char * ggml_backend_rpc_name(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + return rpc_ctx->name.c_str(); +} + +static void ggml_backend_rpc_free(ggml_backend_t backend) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + delete rpc_ctx; + delete backend; +} + +static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); + // this is no-op because we don't have any async operations +} + +static void add_tensor(ggml_tensor * tensor, std::vector & tensors, std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + add_tensor(tensor->src[i], tensors, visited); + } + add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(serialize_tensor(tensor)); +} + +static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < 
n_nodes; i++) { + add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + output.resize(output_size, 0); + uint8_t * dest = output.data(); + memcpy(dest, &device, sizeof(device)); + dest += sizeof(device); + memcpy(dest, &n_nodes, sizeof(n_nodes)); + dest += sizeof(n_nodes); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + dest += n_nodes * sizeof(uint64_t); + memcpy(dest, &n_tensors, sizeof(n_tensors)); + dest += sizeof(n_tensors); + rpc_tensor * out_tensors = (rpc_tensor *)dest; + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); +} + +static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + + GGML_ASSERT(cgraph->n_nodes > 0); + bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + if (reuse) { + rpc_msg_graph_recompute_req request; + request.device = rpc_ctx->device; + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + RPC_STATUS_ASSERT(status); + } else { + rpc_ctx->last_graph_uid = cgraph->uid; + std::vector input; + serialize_graph(rpc_ctx->device, cgraph, input); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; +} + +static ggml_backend_i ggml_backend_rpc_interface = { + /* .get_name = */ ggml_backend_rpc_name, + /* .free = */ ggml_backend_rpc_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ ggml_backend_rpc_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_rpc_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; + // NOTE: buffer types are allocated and never freed; this is by design + static std::unordered_map buft_map; + auto it = buft_map.find(buft_name); + if (it != buft_map.end()) { + return it->second; + } + auto sock = get_socket(endpoint); + if (sock == nullptr) { + GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); + return nullptr; + } + size_t alignment = get_alignment(sock, device); + size_t max_size = get_max_size(sock, device); + ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ buft_name, + /* .alignment = */ alignment, + /* .max_size = */ max_size + }; + auto reg = ggml_backend_rpc_add_server(endpoint); + ggml_backend_buffer_type_t buft = new 
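// Worked payload size for the serialization format above, taking a graph with
// n_nodes = 3 whose nodes pull in n_tensors = 7 records once srcs and views
// are deduplicated (illustrative numbers):
//   | device: 4 | n_nodes: 4 | node ids: 3*8 | n_tensors: 4 | tensors: 7*sizeof(rpc_tensor) |
//   -> 36 + 7 * sizeof(rpc_tensor) bytes; rpc_tensor is static_assert-ed to a
//   multiple of 8, so the flattened array needs no extra alignment padding.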
+
+static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid;
+    if (reuse) {
+        rpc_msg_graph_recompute_req request;
+        request.device = rpc_ctx->device;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
+        RPC_STATUS_ASSERT(status);
+    } else {
+        rpc_ctx->last_graph_uid = cgraph->uid;
+        std::vector<uint8_t> input;
+        serialize_graph(rpc_ctx->device, cgraph, input);
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
+        RPC_STATUS_ASSERT(status);
+    }
+    return GGML_STATUS_SUCCESS;
+}
+
+static ggml_backend_i ggml_backend_rpc_interface = {
+    /* .get_name            = */ ggml_backend_rpc_name,
+    /* .free                = */ ggml_backend_rpc_free,
+    /* .set_tensor_async    = */ NULL,
+    /* .get_tensor_async    = */ NULL,
+    /* .set_tensor_2d_async = */ NULL,
+    /* .get_tensor_2d_async = */ NULL,
+    /* .cpy_tensor_async    = */ NULL,
+    /* .synchronize         = */ ggml_backend_rpc_synchronize,
+    /* .graph_plan_create   = */ NULL,
+    /* .graph_plan_free     = */ NULL,
+    /* .graph_plan_update   = */ NULL,
+    /* .graph_plan_compute  = */ NULL,
+    /* .graph_compute       = */ ggml_backend_rpc_graph_compute,
+    /* .event_record        = */ NULL,
+    /* .event_wait          = */ NULL,
+    /* .graph_optimize      = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+    std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
+    // NOTE: buffer types are allocated and never freed; this is by design
+    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+    auto it = buft_map.find(buft_name);
+    if (it != buft_map.end()) {
+        return it->second;
+    }
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
+        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
+        return nullptr;
+    }
+    size_t alignment = get_alignment(sock, device);
+    size_t max_size  = get_max_size(sock, device);
+    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
+        /* .endpoint  = */ endpoint,
+        /* .device    = */ device,
+        /* .name      = */ buft_name,
+        /* .alignment = */ alignment,
+        /* .max_size  = */ max_size
+    };
+    auto reg = ggml_backend_rpc_add_server(endpoint);
+    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
+        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
+        /* .context = */ buft_ctx
+    };
+    buft_map[buft_name] = buft;
+    return buft;
+}
+
+ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
+    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
+    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
+        /* .endpoint       = */ endpoint,
+        /* .device         = */ device,
+        /* .name           = */ dev_name,
+        /* .last_graph_uid = */ 0,
+    };
+    auto reg = ggml_backend_rpc_add_server(endpoint);
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid    = */ ggml_backend_rpc_guid(),
+        /* .iface   = */ ggml_backend_rpc_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
+        /* .context = */ ctx
+    };
+    return backend;
+}
+
+bool ggml_backend_is_rpc(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
+}
+
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
+    rpc_msg_get_device_memory_req request;
+    request.device = device;
+    rpc_msg_get_device_memory_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
+    RPC_STATUS_ASSERT(status);
+    *free  = response.free_mem;
+    *total = response.total_mem;
+}
+
+void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
+        *free = 0;
+        *total = 0;
+        return;
+    }
+    get_device_memory(sock, device, free, total);
+}
+
+// RPC server-side implementation
+
+class rpc_server {
+public:
+    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
+        : backends(std::move(all_backends)), cache_dir(cache_dir) {
+        stored_graphs.resize(backends.size());
+    }
+    ~rpc_server();
+
+    void hello(rpc_msg_hello_rsp & response);
+    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
+    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
+    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
+    bool free_buffer(const rpc_msg_free_buffer_req & request);
+    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
+    bool set_tensor(const std::vector<uint8_t> & input);
+    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
+    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
+    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
+    bool init_tensor(const rpc_msg_init_tensor_req & request);
+    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
+
+    struct stored_graph {
+        std::vector<uint8_t> buffer;
+        ggml_cgraph * graph;
+    };
+
+private:
+    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
+    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
+    ggml_tensor * create_node(uint64_t id,
+                              struct ggml_context * ctx,
+                              const std::unordered_map<uint64_t, const rpc_tensor *> & tensor_ptrs,
+                              std::unordered_map<uint64_t, ggml_tensor *> & tensor_map);
+
+    std::vector<ggml_backend_t> backends;
+    const char * cache_dir;
+    std::unordered_set<ggml_backend_buffer_t> buffers;
+    // store the last computed graph for each backend
+    std::vector<stored_graph> stored_graphs;
+};
+
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+    response.major = RPC_PROTO_MAJOR_VERSION;
+    response.minor = RPC_PROTO_MINOR_VERSION;
+    response.patch = RPC_PROTO_PATCH_VERSION;
+    LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
+bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (request.srcs[i].id != 0) {
+            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
+        }
+    }
+
+    LOG_DBG("[%s] device: %u, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
+    if (tensor->buffer == nullptr) {
+        // No buffer allocated.
+        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
+    } else {
+        buft = tensor->buffer->buft;
+    }
+
+    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
+
+    return true;
+}
+
+bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
+    response.remote_ptr = 0;
+    response.remote_size = 0;
+    if (buffer != nullptr) {
+        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        response.remote_size = buffer->size;
+        LOG_DBG("[%s] device: %u, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
+                __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
+        buffers.insert(buffer);
+    } else {
+        LOG_DBG("[%s] device: %u, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
+    }
+    return true;
+}
+
+bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+    LOG_DBG("[%s] device: %u, alignment: %zu\n", __func__, dev_id, alignment);
+    response.alignment = alignment;
+    return true;
+}
+
+bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
+    size_t max_size = ggml_backend_buft_get_max_size(buft);
+    LOG_DBG("[%s] device: %u, max_size: %zu\n", __func__, dev_id, max_size);
+    response.max_size = max_size;
+    return true;
+}
+
+bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
+    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+        return false;
+    }
+    void * base = ggml_backend_buffer_get_base(buffer);
+    response.base_ptr = reinterpret_cast<uint64_t>(base);
+    return true;
+}
+
+bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
+    LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+        return false;
+    }
+    ggml_backend_buffer_free(buffer);
+    buffers.erase(buffer);
+    return true;
+}
+
+bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
+    LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+        return false;
+    }
+    ggml_backend_buffer_clear(buffer, request.value);
+    return true;
+}
+
+ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+    // Validate tensor type before using it
+    if (tensor->type >= GGML_TYPE_COUNT) {
+        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
+    // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types)
+    if (ggml_blck_size((enum ggml_type)tensor->type) == 0) {
+        GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
+    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
+                                              tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+
+    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than an invalid type
+    if (result == nullptr) {
+        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
+        return nullptr;
+    }
+
+    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = tensor->nb[i];
+    }
+    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+        result->buffer = nullptr;
+    }
+
+    if (result->buffer) {
+        // require that the tensor data does not go beyond the buffer end
+        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
+        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
+        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
+        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
+        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
+    }
+
+    result->op = (ggml_op) tensor->op;
+    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+        result->op_params[i] = tensor->op_params[i];
+    }
+    result->flags = tensor->flags;
+    result->data = reinterpret_cast<void *>(tensor->data);
+    ggml_set_name(result, tensor->name);
+    return result;
+}
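+
+// Illustrative only: the two asserts above catch both wrap-around and
+// out-of-range placements. E.g., with a 1 MiB buffer based at 0x1000
+// (so the valid range is [0x1000, 0x101000)):
+//
+//   data = 0x1000,   size = 0x100000 -> accepted (exactly fits)
+//   data = 0x100f00, size = 0x200    -> rejected (ends at 0x101100, past the end)
+//   data = 0xfffffffffffffe00, size = 0x400
+//                                    -> rejected by the overflow check, since
+//                                       data + size wraps below data in uint64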
+
+bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
+    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
+        return false;
+    }
+    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+    uint64_t offset;
+    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+    const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
+
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr || tensor->buffer == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        return false;
+    }
+    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, in_tensor->data, offset, size, p0, p1);
+            return false;
+        }
+    }
+
+    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+    if (cache_dir && size > HASH_THRESHOLD) {
+        uint64_t hash = fnv_hash((const uint8_t*)data, size);
+        char hash_str[17];
+        snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+        // save to cache_dir/hash_str
+        fs::path cache_file = fs::path(cache_dir) / hash_str;
+        std::ofstream ofs(cache_file, std::ios::binary);
+        ofs.write((const char *)data, size);
+        GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str());
+    }
+    ggml_backend_tensor_set(tensor, data, offset, size);
+    return true;
+}
+
+bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
+    if (!cache_dir) {
+        return false;
+    }
+    char hash_str[17];
+    snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+    fs::path cache_file = fs::path(cache_dir) / hash_str;
+    std::error_code ec;
+    if (!fs::exists(cache_file, ec)) {
+        return false;
+    }
+    std::ifstream ifs(cache_file, std::ios::binary);
+    ifs.seekg(0, std::ios::end);
+    size_t size = ifs.tellg();
+    ifs.seekg(0, std::ios::beg);
+    data.resize(size);
+    ifs.read((char *)data.data(), size);
+    return true;
+}
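+
+// Illustrative only: a tensor payload hashing to 0x0123456789abcdef is stored
+// as <cache_dir>/0123456789abcdef (16 lowercase hex digits, zero-padded).
+// A client that knows the hash can then skip the bulk upload:
+//
+//   1. client sends RPC_CMD_SET_TENSOR_HASH with { tensor, offset, hash }
+//   2. server answers result=1 if the cached file exists and was applied,
+//      result=0 otherwise
+//   3. on result=0 the client is expected to fall back to a full
+//      RPC_CMD_SET_TENSOR upload (caller behaviour; only steps 1-2 are
+//      implemented in this file)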
+
+bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
+{
+    std::vector<uint8_t> cached_file;
+    if (!get_cached_file(request.hash, cached_file)) {
+        response.result = 0;
+        return true;
+    }
+    size_t size = cached_file.size();
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr || tensor->buffer == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        return false;
+    }
+    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
+            __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (request.tensor.data + request.offset < p0
+            || request.tensor.data + request.offset >= p1
+            || size > (p1 - request.tensor.data - request.offset)) {
+            GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
+            return false;
+        }
+    }
+    ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
+    response.result = 1;
+    return true;
+}
+
+bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
+        return false;
+    }
+    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
+    // Call the backend's buffer_init_tensor function
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer && buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    } else if (!buffer) {
+        GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n");
+    }
+
+    if (tensor->extra != nullptr) {
+        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
+        // Currently unimplemented.
+        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
+        return false;
+    }
+
+    return true;
+}
+
+bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr || tensor->buffer == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+        return false;
+    }
+    LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
+
+    // sanitize tensor->data
+    {
+        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+        if (request.tensor.data + request.offset < p0 ||
+            request.tensor.data + request.offset >= p1 ||
+            request.size > (p1 - request.tensor.data - request.offset)) {
+            GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
+                           __func__, request.tensor.data, request.offset, request.size, p0, p1);
+            return false;
+        }
+    }
+
+    response.resize(request.size, 0);
+    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
+    return true;
+}
+
+bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+
+    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
+    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
+    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
+        GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
+        return false;
+    }
+
+    uint64_t src_size   = (uint64_t) ggml_nbytes(src);
+    uint64_t dst_data   = (uint64_t) dst->data;
+    uint64_t dst_base   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
+    uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
+
+    if (dst_data + src_size > dst_base + dst_buf_sz) {
+        GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
+                       "    write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
+                       "    buffer base : [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
+                       __func__,
+                       dst_data,
+                       dst_data + src_size,
+                       dst_base,
+                       dst_base + dst_buf_sz);
+        return false;
+    }
+
+    LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
+            __func__, (void*) src->buffer, (void*) dst->buffer);
+
+    response.result = ggml_backend_buffer_copy_tensor(src, dst);
+    return true;
+}
+
+ggml_tensor * rpc_server::create_node(uint64_t id,
+                                      struct ggml_context * ctx,
+                                      const std::unordered_map<uint64_t, const rpc_tensor *> & tensor_ptrs,
+                                      std::unordered_map<uint64_t, ggml_tensor *> & tensor_map) {
+    if (tensor_map.find(id) != tensor_map.end()) {
+        return tensor_map[id];
+    }
+    // Safely find the tensor pointer
+    auto it_ptr = tensor_ptrs.find(id);
+    if (it_ptr == tensor_ptrs.end()) {
+        return nullptr;
+    }
+    const rpc_tensor * tensor = it_ptr->second;
+
+    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
+    if (result == nullptr) {
+        return nullptr;
+    }
+    if (result->buffer == nullptr && result->data != nullptr) {
+        GGML_LOG_ERROR("[%s] invalid data ptr\n", __func__);
+        return nullptr;
+    }
+    tensor_map[id] = result;
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        // Check if the source ID is 0 before calling create_node recursively
+        if (tensor->src[i] == 0) {
+            result->src[i] = nullptr;
+        } else {
+            result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+            // If the recursive call failed for a non-zero ID, propagate the error
+            if (result->src[i] == nullptr) {
+                GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                               __func__, i, tensor->src[i], id);
+                // Must return nullptr to signal failure up the call stack
+                return nullptr;
+            }
+        }
+    }
+
+    // Handle view_src similarly
+    if (tensor->view_src == 0) {
+        result->view_src = nullptr;
+    } else {
+        result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+        // If the recursive call failed for a non-zero ID, propagate the error
+        if (result->view_src == nullptr) {
+            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
+                           __func__, tensor->view_src, id);
+            // Must return nullptr to signal failure up the call stack
+            return nullptr;
+        }
+    }
+    result->view_offs = tensor->view_offs;
+    return result;
+}
+
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
+    // serialization format:
+    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < 2*sizeof(uint32_t)) {
+        return false;
+    }
+    const uint8_t * src = input.data();
+    uint32_t device;
+    memcpy(&device, src, sizeof(device));
+    src += sizeof(device);
+    if (device >= backends.size()) {
+        return false;
+    }
+    uint32_t n_nodes;
+    memcpy(&n_nodes, src, sizeof(n_nodes));
+    src += sizeof(n_nodes);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+        return false;
+    }
+    const uint64_t * nodes = (const uint64_t *)src;
+    src += n_nodes*sizeof(uint64_t);
+    uint32_t n_tensors;
+    memcpy(&n_tensors, src, sizeof(n_tensors));
+    src += sizeof(n_tensors);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+        return false;
+    }
+    const rpc_tensor * tensors = (const rpc_tensor *)src;
+    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
+
+    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+    if (stored_graphs[device].buffer.size() < buf_size) {
+        stored_graphs[device].buffer.resize(buf_size);
+    }
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ stored_graphs[device].buffer.data(),
+        /*.no_alloc   =*/ true,
+    };
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    GGML_ASSERT(ctx_ptr != nullptr);
+    ggml_context * ctx = ctx_ptr.get();
+    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
+    graph->n_nodes = n_nodes;
+    std::unordered_map<uint64_t, const rpc_tensor *> tensor_ptrs;
+    tensor_ptrs.reserve(n_tensors);
+    for (uint32_t i = 0; i < n_tensors; i++) {
+        tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
+    }
+    std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
+    tensor_map.reserve(n_nodes);
+    for (uint32_t i = 0; i < n_nodes; i++) {
+        int64_t id;
+        memcpy(&id, &nodes[i], sizeof(id));
+        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
+
+        // Check if create_node failed for a *non-zero* ID.
+        // If id was 0, create_node returning nullptr is expected.
+        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
+        if (graph->nodes[i] == nullptr && id != 0) {
+            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
+            return false;
+        }
+    }
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
+    stored_graphs[device].graph = graph;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
+    uint32_t device = request.device;
+    if (device >= backends.size()) {
+        return false;
+    }
+    if (stored_graphs[device].graph == nullptr) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[device].graph;
+    LOG_DBG("[%s] device: %u\n", __func__, device);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
+    return true;
+}
+
+bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    size_t free, total;
+    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
+    ggml_backend_dev_memory(dev, &free, &total);
+    response.free_mem = free;
+    response.total_mem = total;
+    LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
+    return true;
+}
+
+rpc_server::~rpc_server() {
+    for (auto buffer : buffers) {
+        ggml_backend_buffer_free(buffer);
+    }
+}
+
+static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
+                             socket_ptr sock) {
+    rpc_server server(backends, cache_dir);
+    uint8_t cmd;
+    if (!sock->recv_data(&cmd, 1)) {
+        return;
+    }
+    if (cmd != RPC_CMD_HELLO) {
+        GGML_LOG_ERROR("Expected HELLO command, update client\n");
+        return;
+    }
+
+    // Read input_size and validate protocol version
+    uint64_t hello_input_size;
+    if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) {
+        return;
+    }
+
+    if (hello_input_size != sizeof(rpc_msg_hello_req)) {
+        GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu); client needs upgrade to protocol v%d.x\n",
+                       (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION);
+        return;
+    }
+
+    rpc_msg_hello_req req = {};
+    if (!sock->recv_data(&req, sizeof(req))) {
+        return;
+    }
+
+    rpc_msg_hello_rsp rsp = {};
+    server.hello(rsp);
+    // Advertise server transport capabilities based on client's caps
+    sock->get_caps(rsp.conn_caps);
+    if (!send_msg(sock, &rsp, sizeof(rsp))) {
+        return;
+    }
+
+    // Activate transport upgrade using client's caps
+    sock->update_caps(req.conn_caps);
+    while (true) {
+        if (!sock->recv_data(&cmd, 1)) {
+            break;
+        }
+        if (cmd >= RPC_CMD_COUNT) {
+            // fail fast if the command is invalid
+            GGML_LOG_ERROR("Unknown command: %d\n", cmd);
+            break;
+        }
+        switch (cmd) {
+            case RPC_CMD_HELLO: {
+                // HELLO command is handled above
+                return;
+            }
+            case RPC_CMD_DEVICE_COUNT: {
+                if (!recv_msg(sock, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_device_count_rsp response;
+                response.device_count = backends.size();
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_ALLOC_BUFFER: {
+                rpc_msg_alloc_buffer_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_alloc_buffer_rsp response;
+                if (!server.alloc_buffer(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GET_ALLOC_SIZE: {
+                rpc_msg_get_alloc_size_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_alloc_size_rsp response;
+                if (!server.get_alloc_size(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GET_ALIGNMENT: {
+                rpc_msg_get_alignment_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_alignment_rsp response;
+                if (!server.get_alignment(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GET_MAX_SIZE: {
+                rpc_msg_get_max_size_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_max_size_rsp response;
+                if (!server.get_max_size(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_BUFFER_GET_BASE: {
+                rpc_msg_buffer_get_base_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_buffer_get_base_rsp response;
+                if (!server.buffer_get_base(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_FREE_BUFFER: {
+                rpc_msg_free_buffer_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.free_buffer(request)) {
+                    return;
+                }
+                if (!send_msg(sock, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_BUFFER_CLEAR: {
+                rpc_msg_buffer_clear_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.buffer_clear(request)) {
+                    return;
+                }
+                if (!send_msg(sock, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_SET_TENSOR: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sock, input)) {
+                    return;
+                }
+                if (!server.set_tensor(input)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_SET_TENSOR_HASH: {
+                rpc_msg_set_tensor_hash_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_set_tensor_hash_rsp response;
+                if (!server.set_tensor_hash(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_INIT_TENSOR: {
+                rpc_msg_init_tensor_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.init_tensor(request)) {
+                    return;
+                }
+                if (!send_msg(sock, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GET_TENSOR: {
+                rpc_msg_get_tensor_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                std::vector<uint8_t> response;
+                if (!server.get_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, response.data(), response.size())) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_COPY_TENSOR: {
+                rpc_msg_copy_tensor_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_copy_tensor_rsp response;
+                if (!server.copy_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_COMPUTE: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sock, input)) {
+                    return;
+                }
+                if (!server.graph_compute(input)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.graph_recompute(request)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GET_DEVICE_MEMORY: {
+                rpc_msg_get_device_memory_req request;
+                if (!recv_msg(sock, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_device_memory_rsp response;
+                if (!server.get_device_memory(request, response)) {
+                    return;
+                }
+                if (!send_msg(sock, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
+            default: {
+                GGML_LOG_ERROR("Unknown command: %d\n", cmd);
+                return;
+            }
+        }
+    }
+}
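+
+// Illustrative only: the transport-level framing every command above follows.
+// A client-side device-count query, spelled out byte by byte:
+//
+//   -> 1 byte  RPC_CMD_HELLO, 8-byte payload size, rpc_msg_hello_req (version + conn_caps)
+//   <- 8-byte payload size, rpc_msg_hello_rsp (version + server conn_caps)
+//   -> 1 byte  RPC_CMD_DEVICE_COUNT, 8-byte payload size = 0
+//   <- 8-byte payload size, rpc_msg_device_count_rsp
+//
+// send_msg()/recv_msg() add the length prefix; the command byte itself is
+// unframed. HELLO must come first and exactly once, or the server drops the
+// connection.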
+
+void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                   size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
+    if (n_devices == 0 || devices == nullptr) {
+        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
+        return;
+    }
+    std::vector<ggml_backend_t> backends;
+    printf("Starting RPC server v%d.%d.%d\n",
+           RPC_PROTO_MAJOR_VERSION,
+           RPC_PROTO_MINOR_VERSION,
+           RPC_PROTO_PATCH_VERSION);
+    printf("  endpoint    : %s\n", endpoint);
+    printf("  local cache : %s\n", cache_dir ? cache_dir : "n/a");
+    printf("Devices:\n");
+    for (size_t i = 0; i < n_devices; i++) {
+        auto dev = devices[i];
+        size_t free, total;
+        ggml_backend_dev_memory(dev, &free, &total);
+        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+               total / 1024 / 1024, free / 1024 / 1024);
+        auto backend = ggml_backend_dev_init(dev, nullptr);
+        if (!backend) {
+            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
+            return;
+        }
+        backends.push_back(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+        if (reg) {
+            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+            if (ggml_backend_set_n_threads_fn) {
+                ggml_backend_set_n_threads_fn(backend, n_threads);
+            }
+        }
+    }
+
+    std::string host;
+    int port;
+    if (!parse_endpoint(endpoint, host, port)) {
+        return;
+    }
+
+#ifdef GGML_RPC_RDMA
+    printf("  transport   : TCP (RDMA auto-negotiate enabled)\n");
+#else
+    printf("  transport   : TCP\n");
+#endif // GGML_RPC_RDMA
+    if (!rpc_transport_init()) {
+        fprintf(stderr, "Failed to initialize RPC transport\n");
+        return;
+    }
+    auto server_socket = socket_t::create_server(host.c_str(), port);
+    if (server_socket == nullptr) {
+        fprintf(stderr, "Failed to create server socket\n");
+        return;
+    }
+    while (true) {
+        auto client_socket = server_socket->accept();
+        if (client_socket == nullptr) {
+            fprintf(stderr, "Failed to accept client connection\n");
+            return;
+        }
+        printf("Accepted client connection\n");
+        fflush(stdout);
+        rpc_serve_client(backends, cache_dir, client_socket);
+        printf("Client connection closed\n");
+        fflush(stdout);
+    }
+    rpc_transport_shutdown();
+    for (auto backend : backends) {
+        ggml_backend_free(backend);
+    }
+}
+
+// device interface
+
+struct ggml_backend_rpc_device_context {
+    std::string endpoint;
+    uint32_t    device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->description.c_str();
+}
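+
+// Illustrative only (not part of this file): a minimal host program for the
+// server entry point above, assuming the CPU backend is registered and
+// ggml_backend_dev_by_type() is available:
+//
+//   int main() {
+//       ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+//       ggml_backend_rpc_start_server("0.0.0.0:50052", /*cache_dir=*/ nullptr,
+//                                     /*n_threads=*/ 8, /*n_devices=*/ 1, &dev);
+//   }
+//
+// ggml_backend_rpc_start_server() only returns on setup failure; otherwise it
+// serves one client connection at a time in the accept loop above.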
+
+static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
+}
+
+static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+    // TODO: obtain value from the server
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_rpc_device_get_name(dev);
+    props->description = ggml_backend_rpc_device_get_description(dev);
+    props->type        = ggml_backend_rpc_device_get_type(dev);
+    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ false,
+        /* .events               = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
+
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
+}
+
+static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+    /* .get_name             = */ ggml_backend_rpc_device_get_name,
+    /* .get_description      = */ ggml_backend_rpc_device_get_description,
+    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
+    /* .get_type             = */ ggml_backend_rpc_device_get_type,
+    /* .get_props            = */ ggml_backend_rpc_device_get_props,
+    /* .init_backend         = */ ggml_backend_rpc_device_init,
+    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+struct ggml_backend_rpc_reg_context {
+    std::string name;
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->name.c_str() : "RPC";
+}
+
+static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->devices.size() : 0;
+}
+
+static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    if (ctx == nullptr) {
+        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
+    } else {
+        GGML_ASSERT(index < ctx->devices.size());
+        return ctx->devices[index];
+    }
+}
+
+static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
+        return (void *)ggml_backend_rpc_add_server;
+    }
+    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
+        return (void *)ggml_backend_rpc_start_server;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+    static struct ggml_backend_reg ggml_backend_rpc_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_rpc_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_rpc_reg;
+}
+
+static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
+    auto sock = get_socket(endpoint);
+    if (sock == nullptr) {
+        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
+        return 0;
+    }
+    rpc_msg_device_count_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
+    RPC_STATUS_ASSERT(status);
+    return response.device_count;
+}
+
+static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
+    static std::mutex mutex;
+    static uint32_t dev_id = 0;
+    std::lock_guard lock(mutex);
+    if (reg_map.find(endpoint) != reg_map.end()) {
+        return reg_map[endpoint];
+    }
+    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
+    if (dev_count == 0) {
+        return nullptr;
+    }
+    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
+    ctx->name = "RPC[" + std::string(endpoint) + "]";
+    for (uint32_t ind = 0; ind < dev_count; ind++) {
+        std::string dev_name = "RPC" + std::to_string(dev_id);
+        std::string dev_desc = std::string(endpoint);
+        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
+            /* .endpoint    = */ endpoint,
+            /* .device      = */ ind,
+            /* .name        = */ dev_name,
+            /* .description = */ dev_desc
+        };
+
+        ggml_backend_dev_t dev = new ggml_backend_device {
+            /* .iface   = */ ggml_backend_rpc_device_i,
+            /* .reg     = */ ggml_backend_rpc_reg(),
+            /* .context = */ dev_ctx,
+        };
+        ctx->devices.push_back(dev);
+        dev_id++;
+    }
+    ggml_backend_reg_t reg = new ggml_backend_reg {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_rpc_reg_interface,
+        /* .context     = */ ctx
+    };
+    reg_map[endpoint] = reg;
+    return reg;
+}
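+
+// Illustrative only: device naming is global across servers because dev_id is
+// a running counter, so registering two endpoints in sequence might yield
+// (exact names depend on registration order and remote device counts):
+//
+//   ggml_backend_rpc_add_server("192.168.1.10:50052"); // devices RPC0, RPC1
+//   ggml_backend_rpc_add_server("192.168.1.11:50052"); // devices RPC2, ...
+//
+// Each call queries RPC_CMD_DEVICE_COUNT once, creates one ggml_backend_device
+// per remote device, and caches the registry; repeated calls for the same
+// endpoint return the cached ggml_backend_reg_t.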
+
+GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp
new file mode 100644
index 000000000..a72815242
--- /dev/null
+++ b/ggml/src/ggml-rpc/transport.cpp
@@ -0,0 +1,683 @@
+#include "transport.h"
+#include "ggml-impl.h"
+
+#ifdef _WIN32
+#  define WIN32_LEAN_AND_MEAN
+#  ifndef NOMINMAX
+#    define NOMINMAX
+#  endif
+#  include <winsock2.h>
+#  include <ws2tcpip.h>
+#else
+#  include <arpa/inet.h>
+#  include <netdb.h>
+#  include <netinet/in.h>
+#  include <netinet/tcp.h>
+#  include <sys/socket.h>
+#  include <sys/types.h>
+#  include <unistd.h>
+#endif
+#include <algorithm>
+#include <cstdlib>
+#include <cstring>
+#include <mutex>
+
+#ifdef GGML_RPC_RDMA
+#  include <array>
+#  include <optional>
+#  include <infiniband/verbs.h>
+#  ifndef _WIN32
+#    include <poll.h>
+#  endif
+#endif // GGML_RPC_RDMA
+
+#ifdef _WIN32
+typedef SOCKET sockfd_t;
+using ssize_t = __int64;
+#else
+typedef int sockfd_t;
+#endif
+
+static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
+
+#define LOG_DBG(...) \
+    do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)
+
+#ifdef GGML_RPC_RDMA
+static constexpr size_t RDMA_CHUNK    = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
+static constexpr int    RDMA_RX_DEPTH = 24;         // pre-posted recv ring: 24 × 256 KiB = 6 MiB
+static constexpr size_t RDMA_GID_SIZE = 16;         // RoCE GID / IB GID is always 16 bytes
+using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>;
+
+struct rdma_conn {
+    struct ibv_context * ctx = nullptr;
+    struct ibv_pd * pd  = nullptr;
+    struct ibv_cq * scq = nullptr; // send completions
+    struct ibv_cq * rcq = nullptr; // recv completions
+    struct ibv_qp * qp  = nullptr;
+
+    void * tx_buf = nullptr;
+    struct ibv_mr * tx_mr = nullptr;
+
+    void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
+    struct ibv_mr * rx_mr = nullptr;
+    int rx_head = 0;
+
+    uint32_t max_inline = 0;
+
+    uint8_t * rx_slot(int i) const {
+        return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK;
+    }
+
+    bool post_rx(int i) {
+        struct ibv_sge sge = {};
+        sge.addr   = (uintptr_t)rx_slot(i);
+        sge.length = RDMA_CHUNK;
+        sge.lkey   = rx_mr->lkey;
+        struct ibv_recv_wr wr = {}, * bad = nullptr;
+        wr.wr_id   = (uint64_t)i;
+        wr.sg_list = &sge;
+        wr.num_sge = 1;
+        return ibv_post_recv(qp, &wr, &bad) == 0;
+    }
+
+    ~rdma_conn() {
+        if (tx_mr) ibv_dereg_mr(tx_mr);
+        if (rx_mr) ibv_dereg_mr(rx_mr);
+        free(tx_buf);
+        free(rx_buf);
+        if (qp)  ibv_destroy_qp(qp);
+        if (scq) ibv_destroy_cq(scq);
+        if (rcq) ibv_destroy_cq(rcq);
+        if (pd)  ibv_dealloc_pd(pd);
+        if (ctx) ibv_close_device(ctx);
+    }
+};
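+
+// Illustrative only: how the receive ring above is consumed. Every peer SEND
+// lands in the oldest pre-posted slot, and wr_id carries the slot index so the
+// consumer can re-post it right after copying the data out:
+//
+//   post_rx(0..23)             at activation (see rdma_activate below)
+//   peer sends a chunk         -> completion with wr_id=0, byte_len = chunk size
+//   memcpy out of rx_slot(0)   -> post_rx(0) returns the slot to the ring
+//
+// With 24 slots of 256 KiB the ring absorbs bursts of up to 6 MiB before the
+// sender would start seeing RNR (receiver-not-ready) retries.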
+
+// Local RDMA parameters captured during the probe phase and later consumed
+// by rdma_activate() after the remote side's caps arrive via HELLO.
+struct rdma_local_info {
+    uint32_t qpn = 0;
+    uint32_t psn = 0;
+    uint8_t  gid[RDMA_GID_SIZE] = {};
+    uint8_t  ib_port = 0;
+    int      gid_idx = 0;
+    enum ibv_mtu path_mtu = IBV_MTU_1024;
+};
+
+struct rdma_caps {
+    uint32_t qpn;
+    uint32_t psn;
+    uint8_t  gid[RDMA_GID_SIZE];
+};
+
+static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size");
+
+#endif // GGML_RPC_RDMA
+
+struct socket_t::impl {
+    impl(sockfd_t fd) : use_rdma(false), fd(fd) {}
+    ~impl();
+    bool send_data(const void * data, size_t size);
+    bool recv_data(void * data, size_t size);
+    void get_caps(uint8_t * local_caps);
+    void update_caps(const uint8_t * remote_caps);
+
+#ifdef GGML_RPC_RDMA
+    bool tcp_peer_closed();
+    std::optional<rdma_gid_t> rdma_build_target_gid();
+    bool rdma_probe();
+    bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid);
+    bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc);
+    bool rdma_send(const void * data, size_t size);
+    bool rdma_recv(void * data, size_t size);
+
+    std::unique_ptr<rdma_conn> rdma;
+    rdma_local_info rdma_local = {};
+#endif // GGML_RPC_RDMA
+    bool use_rdma;
+    sockfd_t fd;
+};
+
+socket_t::impl::~impl() {
+#ifdef GGML_RPC_RDMA
+    rdma.reset();
+#endif // GGML_RPC_RDMA
+    LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
+#ifdef _WIN32
+    if (fd != INVALID_SOCKET) closesocket(this->fd);
+#else
+    if (fd >= 0) close(this->fd);
+#endif
+}
+
+#ifdef GGML_RPC_RDMA
+
+bool socket_t::impl::tcp_peer_closed() {
+    if (fd < 0) return false;
+#ifndef _WIN32
+    struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 };
+    int r = poll(&pfd, 1, 0);
+    return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP));
+#else
+    return false;
+#endif
+}
+
+// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address.
+// Used to match the socket's local IP against the kernel's GID table so that
+// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly:
+//   AF_INET                -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4)
+//   AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape)
+//   AF_INET6 (native v6)   -> the 16-byte IPv6 address as-is
+// Returns std::nullopt on unsupported family or getsockname failure.
+std::optional<rdma_gid_t> socket_t::impl::rdma_build_target_gid() {
+    sockaddr_storage addr = {};
+    socklen_t addr_len = sizeof(addr);
+    if (getsockname(fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) {
+        return std::nullopt;
+    }
+    rdma_gid_t target = {};
+    if (addr.ss_family == AF_INET) {
+        const auto * a = reinterpret_cast<const sockaddr_in *>(&addr);
+        target[10] = 0xff;
+        target[11] = 0xff;
+        memcpy(&target[12], &a->sin_addr, 4);
+        return target;
+    }
+    if (addr.ss_family == AF_INET6) {
+        const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr);
+        memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE);
+        return target;
+    }
+    return std::nullopt;
+}
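+
+// Illustrative only: for a socket whose local IPv4 address is 192.0.2.10, the
+// 16-byte target built above is the IPv4-mapped IPv6 form
+//
+//   00 00 00 00 00 00 00 00 00 00 ff ff c0 00 02 0a
+//                                 ^^^^^ ^^^^^^^^^^^
+//                                 bytes  the IPv4
+//                                 10-11  address
+//
+// which is exactly how RoCE GID tables store IPv4 entries, so one memcmp per
+// GID-table row suffices in rdma_probe() below.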
+
+bool socket_t::impl::rdma_probe() {
+    const char * dev_env = std::getenv("GGML_RDMA_DEV");
+    const char * gid_env = std::getenv("GGML_RDMA_GID");
+
+    auto target_gid = rdma_build_target_gid();
+    if (!target_gid) {
+        return false;
+    }
+
+    const uint8_t ib_port = 1;
+    int num_devs = 0;
+    ibv_device ** devs = ibv_get_device_list(&num_devs);
+    if (!devs || num_devs == 0) return false;
+
+    ibv_context * ibctx = nullptr;
+    const char * matched_dev = nullptr;
+    int gid_idx = gid_env ? atoi(gid_env) : -1;
+    int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB
+
+    for (int d = 0; d < num_devs; d++) {
+        const char * dn = ibv_get_device_name(devs[d]);
+        if (dev_env && strcmp(dev_env, dn) != 0) continue;
+
+        ibv_context * ctx = ibv_open_device(devs[d]);
+        if (!ctx) continue;
+
+        ibv_port_attr pa;
+        if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }
+
+        int found_gid = gid_idx;
+        int found_version = IBV_GID_TYPE_IB;
+        if (found_gid < 0) {
+            // Find a GID on this port whose bytes equal the local TCP address
+            // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1
+            // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths
+            // are avoided. ibv_query_gid_ex returns gid+type in one call.
+            int v2_idx = -1;
+            int v1_idx = -1;
+            for (int i = 0; i < pa.gid_tbl_len; i++) {
+                ibv_gid_entry entry = {};
+                if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue;
+                if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue;
+                if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) {
+                    v2_idx = i;
+                } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) {
+                    v1_idx = i;
+                }
+            }
+            if (v2_idx >= 0) {
+                found_gid = v2_idx;
+                found_version = IBV_GID_TYPE_ROCE_V2;
+            } else if (v1_idx >= 0) {
+                found_gid = v1_idx;
+                found_version = IBV_GID_TYPE_ROCE_V1;
+            }
+        } else {
+            // Explicit GID index from GGML_RDMA_GID; fetch its type for logging.
+            ibv_gid_entry entry = {};
+            if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) {
+                found_version = entry.gid_type;
+            }
+        }
+        if (found_gid >= 0) {
+            ibctx = ctx;
+            gid_idx = found_gid;
+            gid_version = found_version;
+            matched_dev = dn;
+            rdma_local.path_mtu = pa.active_mtu;
+            break;
+        }
+        ibv_close_device(ctx);
+    }
+    ibv_free_device_list(devs);
+    if (!ibctx) return false;
+
+    rdma_local.ib_port = ib_port;
+    rdma_local.gid_idx = gid_idx;
+
+    rdma = std::make_unique<rdma_conn>();
+    rdma->ctx = ibctx;
+
+    rdma->pd = ibv_alloc_pd(ibctx);
+    if (!rdma->pd) return false;
+
+    rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
+    rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
+    if (!rdma->scq || !rdma->rcq) return false;
+
+    ibv_qp_init_attr qia = {};
+    qia.send_cq = rdma->scq;
+    qia.recv_cq = rdma->rcq;
+    qia.qp_type = IBV_QPT_RC;
+    qia.cap.max_send_wr     = 4;
+    qia.cap.max_recv_wr     = RDMA_RX_DEPTH + 4;
+    qia.cap.max_send_sge    = 1;
+    qia.cap.max_recv_sge    = 1;
+    qia.cap.max_inline_data = 256;
+
+    rdma->qp = ibv_create_qp(rdma->pd, &qia);
+    if (!rdma->qp) return false;
+    rdma->max_inline = qia.cap.max_inline_data;
+
+    rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
+    rdma->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK);
+    if (!rdma->tx_buf || !rdma->rx_buf) return false;
+
+    rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
+    rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK,
+                             IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+    if (!rdma->tx_mr || !rdma->rx_mr) return false;
+
+    ibv_gid local_gid;
+    if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false;
+
+    rdma_local.qpn = rdma->qp->qp_num;
+    rdma_local.psn = rdma->qp->qp_num & 0xffffff;
+    memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE);
+
+    const char * ver_str = "";
+    if (gid_version == IBV_GID_TYPE_ROCE_V2) {
+        ver_str = " RoCEv2";
+    } else if (gid_version == IBV_GID_TYPE_ROCE_V1) {
+        ver_str = " RoCEv1";
+    }
+    GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n",
+                  matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline);
+    return true;
+}
+
+// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS.
+// On success, the connection is live and ready for rdma_send/rdma_recv.
+bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
+    // RESET -> INIT
+    {
+        struct ibv_qp_attr a = {};
+        a.qp_state        = IBV_QPS_INIT;
+        a.port_num        = rdma_local.ib_port;
+        a.pkey_index      = 0;
+        a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
+        if (ibv_modify_qp(rdma->qp, &a,
+                          IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
+            return false;
+        }
+    }
+
+    for (int i = 0; i < RDMA_RX_DEPTH; i++) {
+        if (!rdma->post_rx(i)) return false;
+    }
+
+    // INIT -> RTR
+    {
+        struct ibv_qp_attr a = {};
+        a.qp_state           = IBV_QPS_RTR;
+        a.path_mtu           = rdma_local.path_mtu;
+        a.dest_qp_num        = remote_qpn;
+        a.rq_psn             = remote_psn;
+        a.max_dest_rd_atomic = 1;
+        a.min_rnr_timer      = 1;
+        a.ah_attr.is_global  = 1;
+        memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE);
+        a.ah_attr.grh.hop_limit  = 1;
+        a.ah_attr.grh.sgid_index = rdma_local.gid_idx;
+        a.ah_attr.dlid           = 0;
+        a.ah_attr.port_num       = rdma_local.ib_port;
+        if (ibv_modify_qp(rdma->qp, &a,
+                          IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
+                          IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
+            return false;
+        }
+    }
+
+    // RTR -> RTS
+    {
+        struct ibv_qp_attr a = {};
+        a.qp_state      = IBV_QPS_RTS;
+        a.timeout       = 14;
+        a.retry_cnt     = 7;
+        a.rnr_retry     = 7;
+        a.sq_psn        = rdma_local.psn;
+        a.max_rd_atomic = 1;
+        if (ibv_modify_qp(rdma->qp, &a,
+                          IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
+                          IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
+            return false;
+        }
+    }
+
+    GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n",
+                  rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH);
+    return true;
+}
+
+bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) {
+    for (uint64_t s = 0; ; s++) {
+        int n = ibv_poll_cq(cq, 1, wc);
+        if (n > 0) {
+            if (wc->status != IBV_WC_SUCCESS) {
+                GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n",
+                               wc->status, ibv_wc_status_str(wc->status), wc->vendor_err);
+            }
+            return wc->status == IBV_WC_SUCCESS;
+        }
+        if (n < 0) return false;
+        if ((s & 0xFFFFF) == 0 && s > 0) {
+            if (tcp_peer_closed()) {
+                return false;
+            }
+        }
+    }
+}
+
+bool socket_t::impl::rdma_send(const void * data, size_t size) {
+    rdma_conn * c = rdma.get();
+    const uint8_t * src = (const uint8_t *)data;
+    size_t rem = size;
+    while (rem > 0) {
+        size_t chunk = std::min(rem, RDMA_CHUNK);
+
+        struct ibv_sge sge = {};
+        struct ibv_send_wr wr = {}, * bad = nullptr;
+        wr.opcode  = IBV_WR_SEND;
+        wr.sg_list = &sge;
+        wr.num_sge = 1;
+
+        if (chunk <= c->max_inline) {
+            sge.addr      = (uintptr_t)src;
+            sge.length    = chunk;
+            wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
+        } else {
+            memcpy(c->tx_buf, src, chunk);
+            sge.addr      = (uintptr_t)c->tx_buf;
+            sge.length    = chunk;
+            sge.lkey      = c->tx_mr->lkey;
+            wr.send_flags = IBV_SEND_SIGNALED;
+        }
+
+        if (ibv_post_send(c->qp, &wr, &bad) != 0) return false;
+        struct ibv_wc wc;
+        if (!rdma_poll(c->scq, &wc)) return false;
+
+        src += chunk;
+        rem -= chunk;
+    }
+    return true;
+}
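+
+// Illustrative only: the send path above picks one of two modes per chunk.
+// Suppose max_inline = 256 (the value requested at QP creation):
+//
+//   12-byte message header -> IBV_SEND_INLINE: the HCA copies the bytes out of
+//                             the WR itself; no lkey needed, tx_buf untouched
+//   256 KiB tensor chunk   -> staged through the registered tx_buf, one
+//                             signaled SEND per chunk, completion polled
+//                             before the next chunk reuses tx_buf
+//
+// Waiting on each completion serializes chunks, which keeps exactly one tx_buf
+// in flight at the cost of pipelining.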
+
+bool socket_t::impl::rdma_recv(void * data, size_t size) {
+    rdma_conn * c = rdma.get();
+    uint8_t * dst = (uint8_t *)data;
+    size_t rem = size;
+    while (rem > 0) {
+        struct ibv_wc wc;
+        if (!rdma_poll(c->rcq, &wc)) return false;
+
+        int slot   = (int)wc.wr_id;
+        size_t got = wc.byte_len;
+        memcpy(dst, c->rx_slot(slot), got);
+
+        if (!c->post_rx(slot)) return false;
+
+        dst += got;
+        rem -= got;
+    }
+    return true;
+}
+
+#endif // GGML_RPC_RDMA
+
+bool socket_t::impl::send_data(const void * data, size_t size) {
+#ifdef GGML_RPC_RDMA
+    if (use_rdma) {
+        return rdma_send(data, size);
+    }
+#endif
+    size_t bytes_sent = 0;
+    while (bytes_sent < size) {
+        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
+        ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0);
+        if (n < 0) {
+            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
+                           bytes_sent, size_to_send);
+            return false;
+        }
+        bytes_sent += (size_t)n;
+    }
+    return true;
+}
+
+bool socket_t::impl::recv_data(void * data, size_t size) {
+#ifdef GGML_RPC_RDMA
+    if (use_rdma) {
+        return rdma_recv(data, size);
+    }
+#endif
+    size_t bytes_recv = 0;
+    while (bytes_recv < size) {
+        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
+        ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0);
+        if (n < 0) {
+            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
+                           bytes_recv, size_to_recv);
+            return false;
+        }
+        if (n == 0) {
+            LOG_DBG("recv returned 0 (peer closed?)\n");
+            return false;
+        }
+        bytes_recv += (size_t)n;
+    }
+    return true;
+}
+
+void socket_t::impl::get_caps(uint8_t * local_caps) {
+    memset(local_caps, 0, RPC_CONN_CAPS_SIZE);
+#ifdef GGML_RPC_RDMA
+    rdma_local = {};
+    if (rdma_probe()) {
+        rdma_caps rc = {};
+        rc.qpn = rdma_local.qpn;
+        rc.psn = rdma_local.psn;
+        memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE);
+        memcpy(local_caps, &rc, sizeof(rc));
+    } else {
+        rdma.reset();
+    }
+#endif // GGML_RPC_RDMA
+}
+
+void socket_t::impl::update_caps(const uint8_t * remote_caps) {
+#ifdef GGML_RPC_RDMA
+    if (!rdma) {
+        return;
+    }
+    rdma_caps rc = {};
+    memcpy(&rc, remote_caps, sizeof(rc));
+    if (rc.qpn == 0) {
+        rdma.reset();
+        return;
+    }
+    if (rdma_activate(rc.qpn, rc.psn, rc.gid)) {
+        use_rdma = true;
+    } else {
+        GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
+        rdma.reset();
+    }
+#else
+    (void)remote_caps;
+#endif // GGML_RPC_RDMA
+}
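+
+// Illustrative only: the TCP-to-RDMA upgrade rides on the HELLO exchange.
+//
+//   client: get_caps()    -> rdma_probe(); on success caps = {qpn, psn, gid},
+//                            otherwise 24 zero bytes (qpn == 0 means "TCP only")
+//   server: get_caps()    -> same probe on its side; caps travel back in the
+//                            HELLO response
+//   both:   update_caps() -> rdma_activate(remote qpn/psn/gid); use_rdma flips
+//                            to true only if both sides probed successfully,
+//                            otherwise the connection silently stays on TCP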
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 000000000..73b85cc53 --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); diff --git a/ggml/src/ggml-sycl/cumsum.cpp b/ggml/src/ggml-sycl/cumsum.cpp deleted file mode 100644 index c1c5fe4fe..000000000 --- a/ggml/src/ggml-sycl/cumsum.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include "cumsum.hpp" -#include "common.hpp" - -#include - -#define SYCL_CUMSUM_BLOCK_SIZE 256 - -static 
diff --git a/ggml/src/ggml-sycl/cumsum.cpp b/ggml/src/ggml-sycl/cumsum.cpp deleted file mode 100644 index c1c5fe4fe..000000000 --- a/ggml/src/ggml-sycl/cumsum.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include "cumsum.hpp" -#include "common.hpp" - -#include - -#define SYCL_CUMSUM_BLOCK_SIZE 256 - -static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) { - return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus<float>()); -} - -static void cumsum_f32_kernel( - const float * __restrict__ src, float * __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t d1, const int64_t d2, const int64_t d3, - const sycl::nd_item<3> & item, float * smem) { - - const int tid = item.get_local_id(2); - const int block_size = item.get_local_range(2); - const int lane = tid % WARP_SIZE; - const int warp = tid / WARP_SIZE; - const int warps_per_block = block_size / WARP_SIZE; - - float * s_vals = smem; - float * s_warp_sums = smem + block_size; - float * s_carry = smem + block_size + warps_per_block; - - if (tid == 0) { - s_carry[0] = 0.0f; - } - item.barrier(sycl::access::fence_space::local_space); - - const int64_t i3 = item.get_group(0); - const int64_t i2 = item.get_group(1); - const int64_t i1 = item.get_group(2); - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { - return; - } - - const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; - float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3; - - constexpr int num_unroll = 4; - float temp[num_unroll]; - - for (int64_t i = 0; i < ne00; i += num_unroll * block_size) { - int64_t idx = i + tid * num_unroll; - - temp[0] = (idx < ne00 ? src_row[idx] : 0.0f); -#pragma unroll - for (int j = 1; j < num_unroll; j++) { - temp[j] = temp[j - 1]; - if (idx + j < ne00) { - temp[j] += src_row[idx + j]; - } - } - - float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f; - - val = warp_prefix_inclusive_sum_f32(val, item); - s_vals[tid] = val; - - if (lane == WARP_SIZE - 1) { - s_warp_sums[warp] = val; - } - item.barrier(sycl::access::fence_space::local_space); - - if (warp == 0) { - float w = (tid < warps_per_block) ? 
s_warp_sums[tid] : 0.0f; - float inc = warp_prefix_inclusive_sum_f32(w, item); - if (tid < warps_per_block) { - s_warp_sums[tid] = inc - w; - } - if (tid == warps_per_block - 1) { - s_carry[1] = inc; - } - } - item.barrier(sycl::access::fence_space::local_space); - - float carry = s_carry[0]; - float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; - -#pragma unroll - for (int j = 0; j < num_unroll; j++) { - if (idx + j < ne00) { - dst_row[idx + j] = temp[j] + final_offset; - } - } - - item.barrier(sycl::access::fence_space::local_space); - - if (tid == 0) { - s_carry[0] += s_carry[1]; - } - } -} - -inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - dpct::queue_ptr stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float * src_d = static_cast<const float *>(src0->data); - float * dst_d = static_cast<float *>(dst->data); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t ts = sizeof(float); - const int64_t s01 = src0->nb[1] / ts; - const int64_t s02 = src0->nb[2] / ts; - const int64_t s03 = src0->nb[3] / ts; - const int64_t d1 = dst->nb[1] / ts; - const int64_t d2 = dst->nb[2] / ts; - const int64_t d3 = dst->nb[3] / ts; - - const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; - int block_size = num_warps * WARP_SIZE; - block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE); - const int warps_per_block = block_size / WARP_SIZE; - const int smem_size = block_size + warps_per_block + 2; - - const sycl::range<3> grid(ne03, ne02, ne01); - const sycl::range<3> block(1, 1, block_size); - - stream->submit([&](sycl::handler & cgh) { - sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh); - cgh.parallel_for( - sycl::nd_range<3>(grid * block, block), - [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03, - s01, s02, s03, d1, d2, d3, - item, get_pointer(smem_acc)); - }); - }); -} - -void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_cumsum(ctx, dst); -} diff --git a/ggml/src/ggml-sycl/cumsum.hpp b/ggml/src/ggml-sycl/cumsum.hpp deleted file mode 100644 index f1a564472..000000000 --- a/ggml/src/ggml-sycl/cumsum.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "common.hpp" - -void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
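The deleted kernel is a tiled inclusive scan: each work-group scans a tile with warp-level prefix sums, then adds a running carry (s_carry) across tiles. A plain C++ sketch of the per-row semantics it reproduces:

#include <cstdint>

// Reference semantics of the cumsum op along ne00: an inclusive prefix
// sum per row; the GPU kernel computes the same result tile-by-tile with
// a carry held in shared local memory.
static void cumsum_row_ref(const float * src, float * dst, int64_t n) {
    float carry = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        carry += src[i];
        dst[i] = carry;
    }
}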
diff --git a/ggml/src/ggml-sycl/diag.cpp b/ggml/src/ggml-sycl/diag.cpp deleted file mode 100644 index c4264fee3..000000000 --- a/ggml/src/ggml-sycl/diag.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "diag.hpp" -#include "common.hpp" - -#define SYCL_DIAG_BLOCK_SIZE 256 - -template <typename T> -static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src, - const int64_t ne0, const int64_t ne1, - const int64_t ne2, const int64_t ne3, - const int64_t total_elements, - const sycl::nd_item<1> & item) { - const int64_t i = item.get_global_id(0); - if (i >= total_elements) { - return; - } - - const int64_t i0 = i % ne0; - const int64_t i1 = (i / ne0) % ne1; - const int64_t i2 = (i / (ne0 * ne1)) % ne2; - const int64_t i3 = i / (ne0 * ne1 * ne2); - - const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; - - if (i0 == i1) { - const int64_t batch_idx = i3 * ne2 + i2; - dst[dst_idx] = src[batch_idx * ne0 + i0]; - } else { - dst[dst_idx] = T(0); - } - - (void)ne3; -} - -inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->ne[1] == 1); - - dpct::queue_ptr stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const void * src0_d = src0->data; - void * dst_d = dst->data; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - const int64_t n_elems = ggml_nelements(dst); - const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE; - - GGML_ASSERT(dst->type == GGML_TYPE_F32); - stream->parallel_for( - sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE), - [=](sycl::nd_item<1> item) { - diag_kernel(static_cast<float *>(dst_d), - static_cast<const float *>(src0_d), - ne0, ne1, ne2, ne3, n_elems, item); - }); -} - -void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_diag(ctx, dst); -} diff --git a/ggml/src/ggml-sycl/diag.hpp b/ggml/src/ggml-sycl/diag.hpp deleted file mode 100644 index 20d7ce489..000000000 --- a/ggml/src/ggml-sycl/diag.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "common.hpp" - -void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/fill.cpp b/ggml/src/ggml-sycl/fill.cpp deleted file mode 100644 index 28e618e4e..000000000 --- a/ggml/src/ggml-sycl/fill.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "fill.hpp" -#include "common.hpp" - -#define SYCL_FILL_BLOCK_SIZE 256 - -template <typename T> -static void fill_kernel(T * dst, const int64_t k, const T value, - const sycl::nd_item<1> & item) { - const int64_t i = (int64_t)item.get_global_id(0); - if (i >= k) { - return; - } - dst[i] = value; -} - -inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst)); - - dpct::queue_ptr stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - float value; - memcpy(&value, dst->op_params, sizeof(float)); - - const int64_t k = ggml_nelements(dst); - const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE; - void * dst_d = dst->data; - - switch (dst->type) { - case GGML_TYPE_F32: - stream->parallel_for( - sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), - [=](sycl::nd_item<1> item) { - fill_kernel(static_cast<float *>(dst_d), k, value, item); - }); - break; - case GGML_TYPE_F16: - { - sycl::half h_value = sycl::half(value); - stream->parallel_for( - sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), - [=](sycl::nd_item<1> item) { - fill_kernel(static_cast<sycl::half *>(dst_d), k, h_value, item); - }); - } - break; - default: - GGML_ABORT("unsupported type"); - } -} - -void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); - ggml_sycl_op_fill(ctx, dst); -} diff --git a/ggml/src/ggml-sycl/fill.hpp b/ggml/src/ggml-sycl/fill.hpp deleted file mode 100644 index b2adb94ff..000000000 --- a/ggml/src/ggml-sycl/fill.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "common.hpp" - -void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
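For reference, the fill op broadcasts a single scalar, taken bit-exact from dst->op_params, into every element; a minimal host-side sketch of those semantics (the argument names are stand-ins for the ggml fields):

#include <cstdint>
#include <cstring>

// What the deleted kernels compute: dst[i] = value for all i, where the
// f32 value is stored in the first op_params slot.
static void fill_ref_f32(float * dst, int64_t n, const int32_t * op_params) {
    float value;
    memcpy(&value, op_params, sizeof(float));
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = value;
    }
}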
diff --git a/ggml/src/ggml-sycl/solve_tri.cpp b/ggml/src/ggml-sycl/solve_tri.cpp deleted file mode 100644 index 39326deee..000000000 --- a/ggml/src/ggml-sycl/solve_tri.cpp +++ /dev/null @@ -1,172 +0,0 @@ -#include "solve_tri.hpp" -#include "common.hpp" -#include - -template <int n_template, int k_template> -static void solve_tri_f32_fast(const float * __restrict__ A, - const float * __restrict__ B, - float * __restrict__ X, - const int64_t ne02, [[maybe_unused]] const int64_t ne03, - const int64_t nb02, const int64_t nb03, - const int64_t nb12, const int64_t nb13, - const int64_t nb2, const int64_t nb3, - const int n_arg, const int k_arg, - const sycl::nd_item<2> & item, float * sA) { - - const int n = n_template == 0 ? n_arg : n_template; - const int k = k_template == 0 ? k_arg : k_template; - - const int batch_idx = item.get_group(1); - const int lane = item.get_local_id(1) % WARP_SIZE; - const int col_idx = item.get_local_id(0); - - if (col_idx >= k) { - return; - } - - const int64_t i03 = batch_idx / ne02; - const int64_t i02 = batch_idx % ne02; - - const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03); - const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13); - float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3); - - const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1); - -#pragma unroll - for (int i = 0; i < n * n; i += k * WARP_SIZE) { - const int i0 = i + offset; - if (i0 < n * n) { - sA[i0] = A_batch[i0]; - } - } - - item.barrier(sycl::access::fence_space::local_space); - - float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; - float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; - - const int half = WARP_SIZE; - const int nrows_low = (n < half) ? n : half; - -#pragma unroll - for (int row = 0; row < nrows_low; ++row) { - float sum = 0.0f; - if (lane < row) { - sum += sA[row * n + lane] * x_low; - } - sum = warp_reduce_sum(sum); - if (lane == row) { - x_low = (x_low - sum) / sA[row * n + row]; - } - } - -#pragma unroll - for (int row = half; row < n; ++row) { - float sum = sA[row * n + lane] * x_low; - const int j = half + lane; - if (j < row) { - sum += sA[row * n + j] * x_high; - } - sum = warp_reduce_sum(sum); - if (lane == row - half) { - x_high = (x_high - sum) / sA[row * n + row]; - } - } - -#pragma unroll - for (int rr = 0; rr < 2; ++rr) { - const int row = rr * WARP_SIZE + lane; - if (row < n) { - const float val = (row < half) ? x_low : x_high; - X_batch[row * k + col_idx] = val; - } - } -} - 
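The fast kernel above is a warp-parallel forward substitution: A is an n x n row-major lower-triangular matrix, and each column of B/X is solved row by row. A serial C++ sketch of the per-batch math, assuming that layout:

// Per-batch reference: solve A * X = B for X, with A lower-triangular
// (n x n, row-major) and B, X of shape n x k. Mirrors the substitution
// order of the kernel, including the division by the diagonal entry.
static void solve_tri_ref(const float * A, const float * B, float * X,
                          int n, int k) {
    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < k; ++col) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[row * n + j] * X[j * k + col];
            }
            X[row * k + col] = (B[row * k + col] - sum) / A[row * n + row];
        }
    }
}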
-static void solve_tri_f32_mkl(dpct::queue_ptr stream, - const float * A, float * X, - int n, int k, - int64_t ne02, [[maybe_unused]] int64_t ne03, - int64_t nb02, [[maybe_unused]] int64_t nb03, - int64_t nb2, [[maybe_unused]] int64_t nb3) { - const float alpha = 1.0f; - const int64_t total_batches = ne02 * ne03; - if (total_batches == 0) { - return; - } - - const int64_t stride_a = nb02 / sizeof(float); - const int64_t stride_x = nb2 / sizeof(float); - - oneapi::mkl::blas::trsm_batch( - *stream, - oneapi::mkl::side::right, - oneapi::mkl::uplo::upper, - oneapi::mkl::transpose::nontrans, - oneapi::mkl::diag::nonunit, - k, n, alpha, - A, n, stride_a, - X, k, stride_x, - total_batches); -} - -inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - dpct::queue_ptr stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const int n = src0->ne[0]; - const int k = src1->ne[0]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K); - - const float * A_d = static_cast<const float *>(src0->data); - const float * B_d = static_cast<const float *>(src1->data); - float * X_d = static_cast<float *>(dst->data); - - if (X_d != B_d) { - const int64_t total_elements = (int64_t)n * k * ne02 * ne03; - stream->memcpy(X_d, B_d, total_elements * sizeof(float)); - } - - const int64_t nb02 = src0->nb[2]; - const int64_t nb03 = src0->nb[3]; - const int64_t nb12 = src1->nb[2]; - const int64_t nb13 = src1->nb[3]; - const int64_t nb2 = dst->nb[2]; - const int64_t nb3 = dst->nb[3]; - - const int64_t total_batches = ne02 * ne03; - - if (n <= 2 * WARP_SIZE && k <= 32) { - const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE; - const sycl::range<2> grid(1, total_batches); - const sycl::range<2> block(k, WARP_SIZE); - stream->submit([&](sycl::handler & cgh) { - sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh); - cgh.parallel_for( - sycl::nd_range<2>(grid * block, block), - [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03, - nb02, nb03, nb12, nb13, nb2, nb3, - n, k, item, get_pointer(smem_acc)); - }); - }); - } else { - solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3); - } -} - -void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); - ggml_sycl_op_solve_tri(ctx, dst); -} diff --git a/ggml/src/ggml-sycl/solve_tri.hpp b/ggml/src/ggml-sycl/solve_tri.hpp deleted file mode 100644 index c7c34cfa2..000000000 --- a/ggml/src/ggml-sycl/solve_tri.hpp +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "common.hpp" - -#define SYCL_SOLVE_TRI_MAX_N 64 -#define SYCL_SOLVE_TRI_MAX_K 64 - -void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ssm_scan.cpp b/ggml/src/ggml-sycl/ssm_scan.cpp deleted file mode 100644 index ae6529813..000000000 --- a/ggml/src/ggml-sycl/ssm_scan.cpp +++ /dev/null @@ -1,156 +0,0 @@ -#include "ssm_scan.hpp" -#include "common.hpp" - -template <int c_factor, int d_state> -static void ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - 
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, - const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, - const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, - const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok, - const sycl::nd_item<2> & item) { - - const int lane = item.get_local_id(1) % WARP_SIZE; - const int warp = item.get_local_id(1) / WARP_SIZE; - const int warp_idx = item.get_group(1) * c_factor + warp; - const int seq_idx = item.get_group(0); - - const int head_idx = warp_idx / d_head; - const int head_off = (warp_idx % d_head) * sizeof(float); - const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - - const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); - const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; - float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - - const int stride_x = src1_nb2 / sizeof(float); - const int stride_dt = src2_nb1 / sizeof(float); - const int stride_B = src4_nb2 / sizeof(float); - const int stride_C = src5_nb2 / sizeof(float); - const int stride_y = n_head * d_head; - - float state[c_factor]; - float state_sum = 0.0f; - -#pragma unroll - for (int j = 0; j < c_factor; j++) { - state[j] = s0_warp[WARP_SIZE * j + lane]; - } - - for (int64_t i = 0; i < n_tok; i++) { - const float dt_val = dt_warp[i * stride_dt]; - const float dt_soft_plus = (dt_val <= 20.0f ? 
sycl::log1p(sycl::exp(dt_val)) : dt_val); - - state_sum = 0.0f; - const float dA = sycl::exp(dt_soft_plus * A_warp[0]); - const float x_dt = x_warp[i * stride_x] * dt_soft_plus; -#pragma unroll - for (int j = 0; j < c_factor; j++) { - const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; - const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; - state[j] = (state[j] * dA) + (B_val * x_dt); - state_sum += state[j] * C_val; - } - - state_sum = warp_reduce_sum(state_sum); - - if (lane == 0) { - y_warp[i * stride_y] = state_sum; - } - } - -#pragma unroll - for (int j = 0; j < c_factor; j++) { - s_warp[WARP_SIZE * j + lane] = state[j]; - } -} - -static void ssm_scan_f32_sycl( - const float * src0, const float * src1, const float * src2, const float * src3, - const float * src4, const float * src5, const int32_t * src6, float * dst, - const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, - const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, - const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, - const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, - dpct::queue_ptr stream) { - - // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! - GGML_ASSERT(src3_nb1 == sizeof(float)); - if (d_state == 128) { - constexpr int threads = 128; - constexpr int num_warps = threads / WARP_SIZE; - const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); - const sycl::range<2> block(1, threads); - stream->parallel_for( - sycl::nd_range<2>(grid * block, block), - [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - ssm_scan_f32_group<128 / WARP_SIZE, 128>( - src0, src1, src2, src3, src4, src5, src6, dst, - src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, - src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); - }); - } else if (d_state == 256) { - constexpr int threads = 256; - constexpr int num_warps = threads / WARP_SIZE; - const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); - const sycl::range<2> block(1, threads); - stream->parallel_for( - sycl::nd_range<2>(grid * block, block), - [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - ssm_scan_f32_group<256 / WARP_SIZE, 256>( - src0, src1, src2, src3, src4, src5, src6, dst, - src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, - src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); - }); - } else { - GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)"); - } -} - -inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; - const ggml_tensor * src3 = dst->src[3]; - const ggml_tensor * src4 = dst->src[4]; - const ggml_tensor * src5 = dst->src[5]; - const ggml_tensor * src6 = dst->src[6]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src6->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const int64_t nc = src0->ne[0]; - const int64_t nr = src0->ne[1]; - const int64_t nh = src1->ne[1]; - const int64_t ng = src4->ne[1]; - const int64_t n_t = src1->ne[2]; - const int64_t n_s = src1->ne[3]; - const int64_t s_off = ggml_nelements(src1) * 
sizeof(float); - - GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst)); - - dpct::queue_ptr stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - ssm_scan_f32_sycl( - static_cast<const float *>(src0->data), static_cast<const float *>(src1->data), - static_cast<const float *>(src2->data), static_cast<const float *>(src3->data), - static_cast<const float *>(src4->data), static_cast<const float *>(src5->data), - static_cast<const int32_t *>(src6->data), static_cast<float *>(dst->data), - src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], - src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], - s_off, nc, nr, nh, ng, n_t, n_s, stream); -} - -void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); - ggml_sycl_op_ssm_scan(ctx, dst); -} diff --git a/ggml/src/ggml-sycl/ssm_scan.hpp b/ggml/src/ggml-sycl/ssm_scan.hpp deleted file mode 100644 index 1f9731fb6..000000000 --- a/ggml/src/ggml-sycl/ssm_scan.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include "common.hpp" - -void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl deleted file mode 100644 index 6ff9bcf2d..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ /dev/null @@ -1,154 +0,0 @@ -#ifdef USE_SUBGROUP_REDUCTION -enable subgroups; -#endif -enable f16; - -#define DECLARE_BYTE_LOADERS_SRC0 -#include "common_decls.tmpl" - -#include "mul_mat_vec_acc.tmpl" - -struct MulMatIdVecParams { - offset_src0: u32, - offset_src1: u32, - offset_ids: u32, - offset_dst: u32, - - k: u32, - m: u32, - n_expert: u32, - n_expert_used: u32, - b_ne1: u32, - - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, -}; - -@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] -@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens(1)] -@group(0) @binding(2) var ids: array; // [n_expert_used, n_tokens(1)] -@group(0) @binding(3) var dst: array; // [rows, n_expert_used, n_tokens(1)] - -// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 -@group(0) @binding(4) var params: MulMatIdVecParams; - -// Flattened as [row][thread] to keep each row's reduction contiguous in memory. -var partial_sums: array; - -fn partial_index(row: u32, thread: u32) -> u32 { - return row * WG_SIZE + thread; -} - -var gathered_count_ids: array; -var gathered_expert_used: array; - -@compute @workgroup_size(WG_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(num_workgroups) num_wg: vec3<u32> -#ifdef USE_SUBGROUP_REDUCTION - , @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_size) subgroup_size: u32 -#endif -) { - - let thread_id = local_id.x; - - for (var i = thread_id;i < params.n_expert;i += WG_SIZE) { - gathered_count_ids[i] = 0; - } - - workgroupBarrier(); - - // gather the selected experts for the target token. 
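What the gather sets up: each expert selected for the token is flagged in gathered_count_ids, and the loop that follows maps a linear workgroup id onto (expert, workgroup-within-expert). A host-side C++ sketch of that mapping (hypothetical names; a single token, where selected[e] is 0 or 1):

#include <cstdint>
#include <vector>

struct wg_map { uint32_t expert; uint32_t wg_in_batch; };

// Mirrors the search over per-expert workgroup counts done by each
// invocation below: expert e contributes output_groups workgroups iff
// it was selected for this token, zero otherwise.
static wg_map map_workgroup(uint32_t wg_linear, uint32_t output_groups,
                            const std::vector<uint8_t> & selected) {
    uint32_t wg_sum = 0;
    for (uint32_t e = 0; e < selected.size(); ++e) {
        const uint32_t wg_per_matrix = output_groups * selected[e];
        if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) {
            return { e, wg_linear - wg_sum };
        }
        wg_sum += wg_per_matrix;
    }
    return { 0, 0 }; // not reached if enough workgroups were dispatched
}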
- for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) { - let expert = ids[params.offset_ids + col]; - gathered_count_ids[expert] = 1; - gathered_expert_used[expert] = col; - } - - workgroupBarrier(); - - let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - - var own_expert:u32 = 0; - var wg_in_batch:u32 = 0; - var wg_sum:u32 = 0; - - for (var i = 0u;i < params.n_expert;i += 1) { - let wg_vec_count = gathered_count_ids[i]; // 1 or 0 - let wg_per_matrix = output_groups * wg_vec_count; - if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { - own_expert = i; - wg_in_batch = wg_linear - wg_sum; - break; - } - wg_sum += wg_per_matrix; - } - - let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; - let dst1_stride = params.m; - - let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02; - let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11; - let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base; - - let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); - -#ifdef USE_SUBGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } - - workgroupBarrier(); - - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } -#endif - -#ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } - - workgroupBarrier(); - - var stride:u32 = WG_SIZE / 2u; - - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; - } - } - - workgroupBarrier(); - stride = stride / 2; - } - - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } - } -#endif -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl deleted file mode 100644 index 1f59bd148..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ /dev/null @@ -1,1391 +0,0 @@ -#ifdef U32_DEQUANT_HELPERS -#define SRC0_TYPE u32 - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} - -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} -#endif - -#ifdef VEC -#define VEC_SIZE 4u -#define SRC0_TYPE vec4 -#define SRC1_TYPE vec4 - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(dot(SRC1_TYPE(src0_val), src1_val)); -} -#endif - -#ifdef SCALAR -#define VEC_SIZE 1u -#define SRC0_TYPE SRC0_INNER_TYPE -#define SRC1_TYPE SRC1_INNER_TYPE - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(src0_val) * 
f32(src1_val); -} -#endif - -#ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let k_vec = params.k / VEC_SIZE; - let src1_idx_base_vec = src1_idx_base / VEC_SIZE; - - // Each thread walks K, loads from the vector, and updates - // a small block of output rows held in registers. - for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q1_0 -#define BLOCK_SIZE 128 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 16 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q4_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % 4; - for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q4_1 -#define BLOCK_SIZE 32 -#define 
BLOCK_SIZE_BYTES 20 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q5_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 22 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q5_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 24 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = 
thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q8_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 34 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q8_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 36 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var 
row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q2_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 84 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let iq = lane / 4u; - let ir = lane % 4u; - let is = ir / 2u; - - let y_offset = 128u * iq + 8u * ir + 4u * phase; - let sc0_byte = 8u * iq + is; - let sc2_byte = 8u * iq + is + 2u; - let sc4_byte = 8u * iq + is + 4u; - let sc6_byte = 8u * iq + is + 6u; - let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let dall = f32(load_f16_at_src0(block_byte_base + 80u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); - - let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); - let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); - let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); - let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); - - let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); - let qs0 = q_u32 & 0xFFFFu; - let qs1 = q_u32 >> 16u; - - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; - - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + 
x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); - - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); - } - } - } - - return acc; -} -#endif - - -#ifdef MUL_ACC_Q3_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let ip = lane / 4u; - let il = 2u * ((lane % 4u) / 2u); - let ir = lane % 2u; - let l0 = 8u * ir; - - let q_byte = 32u + 32u * ip + l0 + 16u * phase; - let h_byte = l0 + 16u * phase; - let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - - let s_shift1 = 4u * ip; - let s_shift2 = s_shift1 + il; - - let v1 = select(64.0, 4.0, il == 0u); - let v2 = 4.0 * v1; - let shift = 2u * il; - - var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; - if (il == 0u) { - qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; - } else { - qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; - } - - let mm_idx = 2u * ip + il / 2u; - var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; - switch (mm_idx) { - case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } - case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } - case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } - default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } - } - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 108u)); - let a_base = 96u; - let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); - let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); - let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); - let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); - - var scales32 = a_4 | (a_5 << 16u); - let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; - scales32 = a_il0 | (a_il1 << 16u); - scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; - - let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); - let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); - - let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); - let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); - let h_u32_0 = load_u32_at_src0(block_byte_base + 
h_byte + 0u); - let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); - } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q4_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 144 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 32u * im + l0; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let scale0 = f32(sc16_0 & 0xFFu); - let scale1 = f32((sc16_0 >> 8u) & 0xFFu); - let min0 = f32(sc16_1 & 0xFFu); - let min1 = f32((sc16_1 >> 8u) & 0xFFu); - let scale2 = f32(sc16_2 & 0xFFu); - let scale3 = f32((sc16_2 >> 8u) & 0xFFu); - let min2 = f32(sc16_3 & 0xFFu); - let min3 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); - let q2_u32 = 
load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } - - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q5_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 176 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 48u + 32u * im + l0; - let qh_offset = 16u + 8u * ir + 4u * in; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let hm1 = 1u << (2u * im); - let hm2 = hm1 << 1u; - let hm3 = hm1 << 4u; - let hm4 = hm2 << 4u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let f0 = f32(sc16_0 & 0xFFu); - let f1 = f32((sc16_0 >> 8u) & 0xFFu); - let m0 = f32(sc16_1 & 0xFFu); - let m1 = f32((sc16_1 >> 8u) & 0xFFu); - let f4 = f32(sc16_2 & 0xFFu); - let f5 = f32((sc16_2 >> 8u) & 0xFFu); - let m4 = f32(sc16_3 & 0xFFu); - let m5 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); - let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); - let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = 
vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); - - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; - - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; - - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; - } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_Q6_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 210 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; - - let num_blocks = params.k / BLOCK_SIZE; - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 208u)); - let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); - let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); - - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - - var sums = vec4(0.0, 0.0, 0.0, 0.0); - - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; - } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); - } - } - } - - return acc; -} -#endif - -#ifdef 
-#ifdef MUL_ACC_IQ1_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 50 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base)); - let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; - let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_IQ1_M -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 56 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let sc_lo = load_u32_at_src0(block_byte_base + 48u); - let sc_hi = load_u32_at_src0(block_byte_base + 52u); - let sc0 = sc_lo & 0xFFFFu; - let sc1 = (sc_lo >> 16u) & 0xFFFFu; - let sc2 = sc_hi & 0xFFFFu; - let sc3 = (sc_hi >> 16u) & 0xFFFFu; - let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); - let d = f32(bitcast<vec2<f16>>(d_bits)[0]); - - let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), - select(sc0, sc1, sub_blk >= 2u), - sub_blk < 4u); - - let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); - let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; - let qh_lo = qh & 0xFFu; - let qh_hi = (qh >> 8u) & 0xFFu; - - var
row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_IQ2_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 66 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let ls = aux_hi >> 28u; - let db = d * (0.5 + f32(ls)) * 0.25; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif -
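The iq2_xxs path above is codebook-based: each group of 8 weights is one byte indexing iq2xxs_grid (8 packed magnitudes) plus a 7-bit index into the shared ksigns table, where bit 7 of each sign byte is the parity of the low seven bits. A hedged scalar equivalent of the inner loop, with table shapes assumed from ggml-common.h:

    #include <stdint.h>

    // Decode one 8-sample iq2_xxs group. grid: 256 codewords of 8 byte
    // magnitudes; ksigns: 128 sign bytes, bit j set => negate sample j.
    static void decode_iq2xxs_group(float d, uint32_t ls, uint32_t grid_idx,
                                    uint32_t sign_idx, const uint8_t grid[256][8],
                                    const uint8_t ksigns[128], float out[8]) {
        const float db    = d * (0.5f + (float) ls) * 0.25f;  // sub-block scale
        const uint8_t sgn = ksigns[sign_idx];
        for (int j = 0; j < 8; ++j) {
            const float mag = (float) grid[grid_idx][j];
            out[j] = ((sgn >> j) & 1) ? -db * mag : db * mag;
        }
    }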
-#ifdef MUL_ACC_IQ2_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 74 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(scales_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_IQ2_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 82 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(sc_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -}
-#endif - -#ifdef MUL_ACC_IQ3_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 98 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); - let ls = aux >> 28u; - let db = d * (0.5 + f32(ls)) * 0.5; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_IQ3_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
- let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); - let sc_word = load_u32_at_src0(block_byte_base + 106u); - let scales_byte = get_byte(sc_word, sub_blk / 2u); - let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let db = d * (1.0 + 2.0 * f32(sub_scale)); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif - -#ifdef MUL_ACC_IQ4_NL -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array<f32, ELEMS_PER_THREAD>; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif -
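iq4_nl (and iq4_xs below) decode nibbles through a fixed non-linear codebook rather than the affine mapping of the q4 formats; the table values below are copied from ggml-common.h (worth re-checking against your tree):

    #include <stdint.h>

    // Non-linear 4-bit codebook shared by iq4_nl and iq4_xs.
    static const int8_t kvalues_iq4nl[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
    };

    // One nibble -> one weight: table lookup scaled by the block's fp16 delta d.
    static inline float iq4nl_decode(uint32_t nibble, float d) {
        return d * (float) kvalues_iq4nl[nibble & 0x0F];
    }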
-#ifdef MUL_ACC_IQ4_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 136 -#define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { - var acc: array<f32, OUTPUTS_PER_WG>; - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let y_offset = sub_blk * 32u + half * 16u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array<f32, 16>; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let scales_h = load_u16_at_src0(block_byte_base + 2u); - let scales_l_word = load_u32_at_src0(block_byte_base + 4u); - let sl_byte = get_byte(scales_l_word, sub_blk / 2u); - let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; - let ls = i32(sl | (sh_bits << 4u)); - let dl = d * f32(ls - 32); - - let qs_byte_off = 8u + sub_blk * 16u; - let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); - let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); - let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); - let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; - } - acc[row] += row_sum; - } - } - } - - return acc; -} -#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl deleted file mode 100644 index e9ef88226..000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +++ /dev/null @@ -1,240 +0,0 @@ -#if defined(SRC_F16) || defined(DST_F16) -enable f16; -#endif - -#ifdef SRC_F16 -#define SRC_TYPE f16 -#else -#define SRC_TYPE f32 -#endif - -#ifdef DST_F16 -#define DST_TYPE f16 -#else -#define DST_TYPE f32 -#endif - -@group(0) @binding(0) -var<storage, read_write> input: array<SRC_TYPE>; - -@group(0) @binding(1) -var<storage, read_write> output: array<DST_TYPE>; - -struct Params { - offset_i: u32, - offset_o: u32, - - // element strides - si0: u32, si1: u32, si2: u32, si3: u32, - so0: u32, so1: u32, so2: u32, so3: u32, - - src_w: u32, - src_h: u32, - src_z: u32, - src_n: u32, - - dst_w: u32, - dst_h: u32, - dst_z: u32, - dst_n: u32, - - mode_flags: u32, -}; - -@group(0) @binding(2) -var<uniform> params: Params; - -const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; - -fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { - let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); - let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); - let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; - return f32(input[i]); -} - -fn cubic_weight(t: f32, a: f32) -> f32 { - let at = abs(t); - if (at <= 1.0) { - return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; - } else if (at <= 2.0) { - return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; - } else { - return 0.0; - } -} - -@compute @workgroup_size(WG_SIZE) -fn main( - @builtin(global_invocation_id) gid: vec3<u32>, - @builtin(num_workgroups) num_wg: vec3<u32> -) { - - let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; - let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; - - if (i_out >= total) { - return; - } - - // decode (x, y, z, n) - var i = i_out; - let x_dst = i % params.dst_w; - i = i / params.dst_w; - let y_dst = i % params.dst_h; - i = i / params.dst_h; - let z_dst = i % params.dst_z; - let n_dst = i / params.dst_z; - - // scale factors - var sf0 = f32(params.dst_w) / f32(params.src_w); - var sf1 = f32(params.dst_h) / f32(params.src_h); - var sf2 = f32(params.dst_z) / f32(params.src_z); - var sf3 =
f32(params.dst_n) / f32(params.src_n); - - let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; - - // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners - var pixel_offset = 0.5; - if (align_corners) { - pixel_offset = 0.0; - if (params.dst_w > 1 && params.src_w > 1) { - sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); - } - if (params.dst_h > 1 && params.src_h > 1) { - sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); - } - } - - let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); - let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); - - var result = 0.0; - -#if defined(NEAREST) - - let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); - let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); - - result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); - -#elif defined(BILINEAR) - -#if defined(ANTIALIAS) - - // Antialiased bilinear: triangle filter over a variable support region. - let support0 = max(1.0f / sf0, 1.0f); - let support1 = max(1.0f / sf1, 1.0f); - let invscale0 = 1.0 / support0; - let invscale1 = 1.0 / support1; - - let fx = (f32(x_dst) + pixel_offset) / sf0; - let fy = (f32(y_dst) + pixel_offset) / sf1; - - let x_min = max(i32(fx - support0 + pixel_offset), 0); - let y_min = max(i32(fy - support1 + pixel_offset), 0); - let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); - let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); - - var weighted_sum = 0.0; - var total_weight = 0.0; - - for (var x = x_min; x < x_max; x += 1) { - let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); - for (var y = y_min; y < y_max; y += 1) { - let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); - let w = wx * wy; - if (w > 0.0) { - weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; - total_weight += w; - } - } - } - - if (total_weight > 0.0) { - result = weighted_sum / total_weight; - } - -#else - - let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; - let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; - let x0 = i32(floor(fx)); - let y0 = i32(floor(fy)); - let dx = clamp(fx - f32(x0), 0.0, 1.0); - let dy = clamp(fy - f32(y0), 0.0, 1.0); - let a = get_clamped_input(x0, y0, z_src, n_src); - let b = get_clamped_input(x0 + 1, y0, z_src, n_src); - let c = get_clamped_input(x0, y0 + 1, z_src, n_src); - let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); - - let wa = (1.0 - dx) * (1.0 - dy); - let wb = dx * (1.0 - dy); - let wc = (1.0 - dx) * dy; - let wd = dx * dy; - - result = a * wa + b * wb + c * wc + d * wd; - -#endif - -#elif defined(BICUBIC) - - // bicubic convolution with alpha = -0.75 (PyTorch default) - let alpha = -0.75; - let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; - let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; - - let x0 = i32(floor(fx)); - let y0 = i32(floor(fy)); - let dx = fx - f32(x0); - let dy = fy - f32(y0); - - // horizontal weights for offsets -1, 0, 1, 2 - let wx0 = cubic_weight(dx + 1.0, alpha); - let wx1 = cubic_weight(dx, alpha); - let wx2 = cubic_weight(1.0 - dx, alpha); - let wx3 = cubic_weight(2.0 - dx, alpha); - - // vertical weights for offsets -1, 0, 1, 2 - let wy0 = cubic_weight(dy + 1.0, alpha); - let wy1 = cubic_weight(dy, alpha); - let wy2 = cubic_weight(1.0 - dy, alpha); - let wy3 = cubic_weight(2.0 - dy, alpha); - - // intermediate horizontal interpolation for 4x4 grid of pixels - // x0-1, x0, x0+1, x0+2, y0-1 - let p0 = 
get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); - let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); - let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); - let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); - let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; - - // x0-1, x0, x0+1, x0+2, y0 - let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); - let q1 = get_clamped_input(x0, y0, z_src, n_src); - let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); - let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); - let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; - - // x0-1, x0, x0+1, x0+2, y0+1 - let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); - let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); - let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); - let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); - let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; - - // x0-1, x0, x0+1, x0+2, y0+2 - let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); - let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); - let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); - let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); - let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; - - // final vertical interpolation - result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; - -#endif - - let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; - output[dst_idx] = DST_TYPE(result); -}
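The deleted upscale shader implements nearest, (optionally antialiased) bilinear, and bicubic filtering; the bicubic path is the standard Keys kernel with a = -0.75 (PyTorch's default) applied separably over a 4x4 neighborhood. A host-side reference of the two core pieces, for comparison against the WGSL above (names are ours):

    #include <cmath>
    #include <cstdint>

    // Keys cubic kernel, a = -0.75: same piecewise cubic as cubic_weight() in
    // the shader, with support of two pixels on each side.
    static float cubic_weight_ref(float t, float a = -0.75f) {
        const float at = std::fabs(t);
        if (at <= 1.0f) return (a + 2.0f) * at * at * at - (a + 3.0f) * at * at + 1.0f;
        if (at <= 2.0f) return a * (at * at * at - 5.0f * at * at + 8.0f * at - 4.0f);
        return 0.0f;
    }

    // Destination index -> source coordinate, as in the shader:
    // pixel_offset is 0.5 for half-pixel centers, 0.0 when align_corners is set.
    static float src_coord(uint32_t dst, float scale, float pixel_offset) {
        return ((float) dst + pixel_offset) / scale - pixel_offset;
    }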
diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp new file mode 100644 index 000000000..08e680391 --- /dev/null +++ b/tools/rpc/rpc-server.cpp @@ -0,0 +1,342 @@ +#include "ggml-rpc.h" +#ifdef _WIN32 +# define NOMINMAX +# define DIRECTORY_SEPARATOR '\\' +# include <windows.h> +# include <fcntl.h> +# include <io.h> +#else +# define DIRECTORY_SEPARATOR '/' +# include <unistd.h> +# include <sys/stat.h> +#endif +#include <algorithm> +#include <clocale> +#include <cstdio> +#include <cstdlib> +#include <regex> +#include <stdexcept> +#include <string> +#include <thread> +#include <vector> + +#if defined(__linux__) +#include <pwd.h> +#include <unistd.h> +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +// returns true if successful, false otherwise +static bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring wpath = utf8_to_wstring(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = CreateDirectoryW(subpath.c_str(), NULL); + + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +static std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else if (std::getenv("HOME")) { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } else { +#if defined(__linux__) + /* no $HOME is defined, fallback to getpwuid */ + struct passwd *pw = getpwuid(getuid()); + if ((!pw) || (!pw->pw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); +#else +# error Unknown architecture +#endif + cache_directory = ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; + } + return ensure_trailing_slash(cache_directory); } + +struct rpc_server_params { + std::string host = "127.0.0.1"; + int port = 50052; + bool use_cache = false; + int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); + std::vector<std::string> devices; +}; + +static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { + fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads); + fprintf(stderr, " -d, --device comma-separated list of devices\n"); + fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); +
fprintf(stderr, " -c, --cache enable local file cache\n"); + fprintf(stderr, "\n"); +} + +static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) { + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg == "-H" || arg == "--host") { + if (++i >= argc) { + return false; + } + params.host = argv[i]; + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + return false; + } + params.n_threads = std::stoi(argv[i]); + if (params.n_threads <= 0) { + fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads); + return false; + } + } else if (arg == "-d" || arg == "--device") { + if (++i >= argc) { + return false; + } + const std::regex regex{ R"([,/]+)" }; + std::string dev_str = argv[i]; + std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + params.devices.push_back(*iter); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str()); + return false; + } + } + } else if (arg == "-p" || arg == "--port") { + if (++i >= argc) { + return false; + } + params.port = std::stoi(argv[i]); + if (params.port <= 0 || params.port > 65535) { + return false; + } + } else if (arg == "-c" || arg == "--cache") { + params.use_cache = true; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + print_usage(argc, argv, params); + exit(0); + } + } + return true; +} + +static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) { + std::vector<ggml_backend_dev_t> devices; + if (!params.devices.empty()) { + for (auto device : params.devices) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "error: unknown device: %s\n", device.c_str()); + fprintf(stderr, "available devices:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } + return {}; + } + } + } + + // Try non-CPU devices first + if (devices.empty()) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) { + devices.push_back(dev); + } + } + } + + // If there are no accelerators, fallback to CPU device + if (devices.empty()) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + devices.push_back(dev); + } + } + + return devices; +} + +int main(int argc, char * argv[]) { + std::setlocale(LC_NUMERIC, "C"); + + ggml_backend_load_all(); + + rpc_server_params params; + if (!rpc_server_params_parse(argc, argv, params)) { + fprintf(stderr, "Invalid parameters\n"); + return 1; + } + + if (params.host != "127.0.0.1") { + fprintf(stderr, "\n"); + fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); + fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str()); + fprintf(stderr, " Never expose the RPC server to an open network!\n"); + fprintf(stderr, " This is an experimental feature and is not secure!\n"); + fprintf(stderr,
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); + fprintf(stderr, "\n"); + } + + auto devices = get_devices(params); + if (devices.empty()) { + fprintf(stderr, "No devices found\n"); + return 1; + } + std::string endpoint = params.host + ":" + std::to_string(params.port); + const char * cache_dir = nullptr; + std::string cache_dir_str; + if (params.use_cache) { + cache_dir_str = fs_get_cache_directory() + "rpc" + DIRECTORY_SEPARATOR; + if (!fs_create_directory_with_parents(cache_dir_str)) { + fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str()); + return 1; + } + cache_dir = cache_dir_str.c_str(); + } + + ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC"); + if (!reg) { + fprintf(stderr, "Failed to find RPC backend\n"); + return 1; + } + + auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server"); + if (!start_server_fn) { + fprintf(stderr, "Failed to obtain RPC backend start server function\n"); + return 1; + } + + start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data()); + return 0; +}