mirror of https://github.com/LostRuins/koboldcpp.git
synced 2026-05-17 04:09:19 +00:00
reinstate rpc files
This commit is contained in:
parent 216901034a
commit 165f6046b2
26 changed files with 3033 additions and 6165 deletions

@@ -1,955 +0,0 @@
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

#define HTP_GDN_MAX_SV 128

struct htp_gdn_context {
    struct htp_ops_context * octx;
    uint32_t  rows_per_thread;
    size_t    state_bytes;
    bool      use_vtcm;
    uint8_t * vtcm_state_base;
    size_t    vtcm_state_per_thread;
};

static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul,
                                    const float * restrict dot, uint32_t n) {
    HVX_Vector acc = Q6_V_vzero();

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vd   = hvx_vmemu(dst + i * epv);
        HVX_Vector vm   = hvx_vmem(mul + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vm);
        hvx_vmemu(dst + i * epv) = out;
        acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vd   = hvx_vmemu(dst + off);
        HVX_Vector vm   = hvx_vmem(mul + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vm);
        hvx_vec_store_u(dst + off, tail * sizeof(float), out);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     prod = hvx_vec_mul_f32_f32(out, vdot);
        acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
    }

    return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
}
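
In scalar terms, this routine applies mul to dst in place and returns the dot product of the updated dst with dot; the HVX version above just batches 32 floats per 128-byte vector and masks the tail. A minimal plain-C reference sketch (illustration only, not part of the original file):

// Scalar reference for gdn_mul_dot_f32.
static float gdn_mul_dot_f32_ref(float * dst, const float * mul,
                                 const float * dot, uint32_t n) {
    float acc = 0.0f;
    for (uint32_t i = 0; i < n; ++i) {
        dst[i] *= mul[i];
        acc    += dst[i] * dot[i];
    }
    return acc;
}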

static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul,
                                           const float * restrict dot, uint32_t n) {
    HVX_Vector acc = Q6_V_vzero();
    const HVX_Vector vmul = hvx_vec_splat_f32(mul);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vd   = hvx_vmemu(dst + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vmul);
        hvx_vmemu(dst + i * epv) = out;
        acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vd   = hvx_vmemu(dst + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_Vector out  = hvx_vec_mul_f32_f32(vd, vmul);
        hvx_vec_store_u(dst + off, tail * sizeof(float), out);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     prod = hvx_vec_mul_f32_f32(out, vdot);
        acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
    }

    return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
}

static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src,
                                           float scale, const float * restrict dot, uint32_t n) {
    HVX_Vector acc = Q6_V_vzero();
    const HVX_Vector vscale = hvx_vec_splat_f32(scale);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vd   = hvx_vmemu(dst + i * epv);
        HVX_Vector vs   = hvx_vmem(src + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);
        HVX_Vector out  = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
        hvx_vmemu(dst + i * epv) = out;
        acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vd   = hvx_vmemu(dst + off);
        HVX_Vector vs   = hvx_vmem(src + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_Vector out  = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale));
        hvx_vec_store_u(dst + off, tail * sizeof(float), out);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     prod = hvx_vec_mul_f32_f32(out, vdot);
        acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero()));
    }

    return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc));
}
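
The add-scaled variant follows the same pattern: dst accumulates a scaled copy of src, and the dot product of the updated dst with dot is returned. A plain-C reference sketch (illustration only):

// Scalar reference for gdn_add_scaled_dot_f32.
static float gdn_add_scaled_dot_f32_ref(float * dst, const float * src, float scale,
                                        const float * dot, uint32_t n) {
    float acc = 0.0f;
    for (uint32_t i = 0; i < n; ++i) {
        dst[i] += src[i] * scale;
        acc    += dst[i] * dot[i];
    }
    return acc;
}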

static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1,
                                    float * restrict dst2, float * restrict dst3,
                                    const float * restrict mul, const float * restrict dot,
                                    uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vm   = hvx_vmem(mul + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vm   = hvx_vmem(mul + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
    }

    HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
    hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
}
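
The x4 variant processes four state rows against the same mul/dot vectors, so the shared operands are loaded once per iteration and the four independent accumulators help hide latency. Assuming hvx_vec_reduce_sum_f32x4 packs the four per-accumulator sums into the leading lanes (which is how every caller consumes sums[0..3]), the whole routine is equivalent to this scalar sketch (hypothetical reference, not the HVX code):

// Scalar sketch of gdn_mul_dot4_f32: four rows share one pass over mul/dot.
static void gdn_mul_dot4_f32_ref(float * d0, float * d1, float * d2, float * d3,
                                 const float * mul, const float * dot,
                                 uint32_t n, float * sums) {
    float * d[4] = { d0, d1, d2, d3 };
    for (int r = 0; r < 4; ++r) {
        float acc = 0.0f;
        for (uint32_t i = 0; i < n; ++i) {
            d[r][i] *= mul[i];
            acc     += d[r][i] * dot[i];
        }
        sums[r] = acc;
    }
}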

static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3,
                                           float mul, const float * restrict dot,
                                           uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    const HVX_Vector vmul = hvx_vec_splat_f32(mul);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
    }

    HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
    hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
}

static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3,
                                           const float * restrict src, const float * restrict scale,
                                           const float * restrict dot, uint32_t n,
                                           float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
    const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
    const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
    const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vs   = hvx_vmem(src + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vs   = hvx_vmem(src + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
    }

    HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } };
    hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc));
}

static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1,
                                    float * restrict dst2, float * restrict dst3,
                                    float * restrict dst4, float * restrict dst5,
                                    float * restrict dst6, float * restrict dst7,
                                    const float * restrict mul, const float * restrict dot,
                                    uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vm   = hvx_vmem(mul + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vm   = hvx_vmem(mul + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3,
                                           float * restrict dst4, float * restrict dst5,
                                           float * restrict dst6, float * restrict dst7,
                                           float mul, const float * restrict dot,
                                           uint32_t n, float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();
    const HVX_Vector vmul = hvx_vec_splat_f32(mul);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul);

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul);
        HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul);
        HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul);
        HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul);
        HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul);
        HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul);
        HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul);
        HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul);

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1,
                                           float * restrict dst2, float * restrict dst3,
                                           float * restrict dst4, float * restrict dst5,
                                           float * restrict dst6, float * restrict dst7,
                                           const float * restrict src, const float * restrict scale,
                                           const float * restrict dot, uint32_t n,
                                           float * restrict sums) {
    HVX_Vector acc0 = Q6_V_vzero();
    HVX_Vector acc1 = Q6_V_vzero();
    HVX_Vector acc2 = Q6_V_vzero();
    HVX_Vector acc3 = Q6_V_vzero();
    HVX_Vector acc4 = Q6_V_vzero();
    HVX_Vector acc5 = Q6_V_vzero();
    HVX_Vector acc6 = Q6_V_vzero();
    HVX_Vector acc7 = Q6_V_vzero();
    const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]);
    const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]);
    const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]);
    const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]);
    const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]);
    const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]);
    const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]);
    const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]);

    const uint32_t epv  = 128 / sizeof(float);
    const uint32_t nvec = n / epv;
    const uint32_t tail = n % epv;
    for (uint32_t i = 0; i < nvec; ++i) {
        HVX_Vector vs   = hvx_vmem(src + i * epv);
        HVX_Vector vdot = hvx_vmem(dot + i * epv);

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3));
        HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4));
        HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5));
        HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6));
        HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7));

        hvx_vmemu(dst0 + i * epv) = out0;
        hvx_vmemu(dst1 + i * epv) = out1;
        hvx_vmemu(dst2 + i * epv) = out2;
        hvx_vmemu(dst3 + i * epv) = out3;
        hvx_vmemu(dst4 + i * epv) = out4;
        hvx_vmemu(dst5 + i * epv) = out5;
        hvx_vmemu(dst6 + i * epv) = out6;
        hvx_vmemu(dst7 + i * epv) = out7;

        acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot));
        acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot));
        acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot));
        acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot));
        acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot));
        acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot));
        acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot));
        acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot));
    }

    if (tail) {
        const uint32_t off = nvec * epv;
        HVX_Vector vs   = hvx_vmem(src + off);
        HVX_Vector vdot = hvx_vmem(dot + off);
        HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float));
        HVX_Vector     zero = Q6_V_vzero();

        HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0));
        HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1));
        HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2));
        HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3));
        HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4));
        HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5));
        HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6));
        HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7));

        hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0);
        hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1);
        hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2);
        hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3);
        hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4);
        hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5);
        hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6);
        hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7);

        acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero));
        acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero));
        acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero));
        acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero));
        acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero));
        acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero));
        acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero));
        acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero));
    }

    HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } };
    HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } };
    hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA));
    hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB));
}

// Prompt-processing path: multiple tokens per sequence, one worker per (head, sequence) row.
static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) {
    struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
    struct htp_ops_context * octx = gctx->octx;

    const struct htp_tensor * q     = octx->src[0];
    const struct htp_tensor * k     = octx->src[1];
    const struct htp_tensor * v     = octx->src[2];
    const struct htp_tensor * g     = octx->src[3];
    const struct htp_tensor * beta  = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst   = octx->dst;

    const uint32_t S_v      = v->ne[0];
    const uint32_t H        = v->ne[1];
    const uint32_t n_tokens = v->ne[2];
    const uint32_t n_seqs   = v->ne[3];

    const uint32_t total_rows = H * n_seqs;
    if (ith >= total_rows) {
        return;
    }

    const uint32_t rq3 = n_seqs / q->ne[3];
    const uint32_t rk3 = n_seqs / k->ne[3];
    const float scale = 1.0f / sqrtf((float) S_v);

    float * dst_base       = (float *) (uintptr_t) dst->data;
    float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs;
    const float * state_in_base = (const float *) (uintptr_t) state->data;

    const bool kda = (g->ne[0] == S_v); // per-channel gates (KDA) vs one scalar gate per head
    float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_q[HTP_GDN_MAX_SV]    __attribute__((aligned(128)));
    float local_k[HTP_GDN_MAX_SV]    __attribute__((aligned(128)));
    float local_sums[4]              __attribute__((aligned(128)));

    for (uint32_t ir = ith; ir < total_rows; ir += nth) {
        const uint32_t iv1 = ir % H;
        const uint32_t iv3 = ir / H;

        const uint32_t iq1 = iv1 % q->ne[1];
        const uint32_t ik1 = iv1 % k->ne[1];
        const uint32_t iq3 = iv3 / rq3;
        const uint32_t ik3 = iv3 / rk3;

        float * s_out      = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        const float * s_in = state_in_base  + ((uint64_t) iv3 * H + iv1) * S_v * S_v;

        memcpy(s_out, s_in, gctx->state_bytes);
        float * s_work = s_out;

        float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;

        for (uint32_t t = 0; t < n_tokens; ++t) {
            const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
                    (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]);
            const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
                    (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]);
            const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
                    (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]);
            const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
                    (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]);
            const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
                    (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]);

            memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
            memcpy(local_k, k_t, (size_t) S_v * sizeof(float));

            if (kda) {
                hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);

                uint32_t j = 0;
                for (; j + 4 <= S_v; j += 4) {
                    float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                    float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                    float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                    float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                    gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
                    float local_delta_b[4] __attribute__((aligned(128)));
                    for (uint32_t r = 0; r < 4; ++r) {
                        local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                    }
                    gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                    for (uint32_t r = 0; r < 4; ++r) {
                        attn_data[j + r] = local_sums[r] * scale;
                    }
                }
                for (; j < S_v; ++j) {
                    float * row = s_work + (uint64_t) j * S_v;
                    const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
                    const float dj  = (v_t[j] - sum) * beta_val;
                    attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
                }
            } else {
                const float gate = expf(g_t[0]);
                uint32_t j = 0;
                for (; j + 4 <= S_v; j += 4) {
                    float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                    float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                    float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                    float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                    gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
                    float local_delta_b[4] __attribute__((aligned(128)));
                    for (uint32_t r = 0; r < 4; ++r) {
                        local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                    }
                    gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                    for (uint32_t r = 0; r < 4; ++r) {
                        attn_data[j + r] = local_sums[r] * scale;
                    }
                }
                for (; j < S_v; ++j) {
                    float * row = s_work + (uint64_t) j * S_v;
                    const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
                    const float dj  = (v_t[j] - sum) * beta_val;
                    attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
                }
            }

            attn_data += (uint64_t) S_v * H;
        }
    }
}
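
In linear-algebra terms, each per-token update above is the gated delta rule: the state row decays under the gate, a rank-1 correction driven by the prediction error is added, and the output is a readout against q. A compact scalar sketch of one (head, sequence) token step, ignoring vectorization and threading (reference only):

// Reference semantics of one token step on state S (S_v x S_v rows x cols).
// KDA path: gate[c] = exp(g[c]); GDN path: every gate[c] is the same exp(g[0]).
static void gdn_token_step_ref(float * S, const float * q, const float * k,
                               const float * v, const float * gate, float beta,
                               float scale, float * out, uint32_t S_v) {
    for (uint32_t j = 0; j < S_v; ++j) {
        float * row = S + (size_t) j * S_v;
        float sum = 0.0f;
        for (uint32_t c = 0; c < S_v; ++c) {
            row[c] *= gate[c];            // decay       (gdn_mul_dot*)
            sum    += row[c] * k[c];      // S k         (same fused pass)
        }
        const float dj = (v[j] - sum) * beta; // scaled delta-rule error
        float outj = 0.0f;
        for (uint32_t c = 0; c < S_v; ++c) {
            row[c] += dj * k[c];          // rank-1 update (gdn_add_scaled_dot*)
            outj   += row[c] * q[c];      // readout       (same fused pass)
        }
        out[j] = outj * scale;
    }
}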

// Token-generation path: a single token per sequence; the S_v x S_v state can be
// staged in VTCM via DMA instead of being copied through DDR.
static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) {
    struct htp_gdn_context * gctx = (struct htp_gdn_context *) data;
    struct htp_ops_context * octx = gctx->octx;

    const struct htp_tensor * q     = octx->src[0];
    const struct htp_tensor * k     = octx->src[1];
    const struct htp_tensor * v     = octx->src[2];
    const struct htp_tensor * g     = octx->src[3];
    const struct htp_tensor * beta  = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst   = octx->dst;

    const uint32_t S_v    = v->ne[0];
    const uint32_t H      = v->ne[1];
    const uint32_t n_seqs = v->ne[3];

    const uint32_t total_rows = H * n_seqs;
    if (ith >= total_rows) {
        return;
    }

    const uint32_t rq3 = n_seqs / q->ne[3];
    const uint32_t rk3 = n_seqs / k->ne[3];
    const float scale = 1.0f / sqrtf((float) S_v);

    float * dst_base       = (float *) (uintptr_t) dst->data;
    float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs;
    const float * state_in_base = (const float *) (uintptr_t) state->data;

    const bool kda = (g->ne[0] == S_v);
    float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128)));
    float local_q[HTP_GDN_MAX_SV]    __attribute__((aligned(128)));
    float local_k[HTP_GDN_MAX_SV]    __attribute__((aligned(128)));
    float local_sums[8]              __attribute__((aligned(128)));

    dma_queue * dma = octx->ctx->dma[ith];

    uint8_t * spad = NULL;
    if (gctx->use_vtcm) {
        spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith;
    }

    for (uint32_t ir = ith; ir < total_rows; ir += nth) {
        const uint32_t iv1 = ir % H;
        const uint32_t iv3 = ir / H;

        const uint32_t iq1 = iv1 % q->ne[1];
        const uint32_t ik1 = iv1 % k->ne[1];
        const uint32_t iq3 = iv3 / rq3;
        const uint32_t ik3 = iv3 / rk3;

        float * s_out      = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        const float * s_in = state_in_base  + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
        float * s_work;

        if (spad) {
            // stage the incoming state into the per-thread VTCM scratchpad
            dma_queue_push(dma, dma_make_ptr(spad, s_in),
                           S_v * sizeof(float), S_v * sizeof(float),
                           S_v * sizeof(float), S_v);
            dma_queue_pop(dma);
            s_work = (float *) spad;
        } else {
            s_work = s_out;
            memcpy(s_work, s_in, gctx->state_bytes);
        }

        float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;

        const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data +
                (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]);
        const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data +
                (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]);
        const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data +
                (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]);
        const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data +
                (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]);
        const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data +
                (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]);

        memcpy(local_q, q_t, (size_t) S_v * sizeof(float));
        memcpy(local_k, k_t, (size_t) S_v * sizeof(float));

        if (kda) {
            hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false);

            uint32_t j = 0;
            for (; j + 8 <= S_v; j += 8) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                float * row4 = s_work + (uint64_t) (j + 4) * S_v;
                float * row5 = s_work + (uint64_t) (j + 5) * S_v;
                float * row6 = s_work + (uint64_t) (j + 6) * S_v;
                float * row7 = s_work + (uint64_t) (j + 7) * S_v;
                gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                                 local_gate, local_k, S_v, local_sums);
                float local_delta_b[8] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 8; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                                        local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 8; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j + 4 <= S_v; j += 4) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums);
                float local_delta_b[4] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 4; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 4; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j < S_v; ++j) {
                float * row = s_work + (uint64_t) j * S_v;
                const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v);
                const float dj  = (v_t[j] - sum) * beta_val;
                attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
            }
        } else {
            const float gate = expf(g_t[0]);
            uint32_t j = 0;
            for (; j + 8 <= S_v; j += 8) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                float * row4 = s_work + (uint64_t) (j + 4) * S_v;
                float * row5 = s_work + (uint64_t) (j + 5) * S_v;
                float * row6 = s_work + (uint64_t) (j + 6) * S_v;
                float * row7 = s_work + (uint64_t) (j + 7) * S_v;
                gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                                        gate, local_k, S_v, local_sums);
                float local_delta_b[8] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 8; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7,
                                        local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 8; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j + 4 <= S_v; j += 4) {
                float * row0 = s_work + (uint64_t) (j + 0) * S_v;
                float * row1 = s_work + (uint64_t) (j + 1) * S_v;
                float * row2 = s_work + (uint64_t) (j + 2) * S_v;
                float * row3 = s_work + (uint64_t) (j + 3) * S_v;
                gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums);
                float local_delta_b[4] __attribute__((aligned(128)));
                for (uint32_t r = 0; r < 4; ++r) {
                    local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val;
                }
                gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums);
                for (uint32_t r = 0; r < 4; ++r) {
                    attn_data[j + r] = local_sums[r] * scale;
                }
            }
            for (; j < S_v; ++j) {
                float * row = s_work + (uint64_t) j * S_v;
                const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v);
                const float dj  = (v_t[j] - sum) * beta_val;
                attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale;
            }
        }

        if (spad) {
            // DMA the updated state back out of VTCM
            dma_queue_push(dma, dma_make_ptr(s_out, spad),
                           S_v * sizeof(float), S_v * sizeof(float),
                           S_v * sizeof(float), S_v);
            dma_queue_pop(dma);
        }
    }
}

int op_gated_delta_net(struct htp_ops_context * octx) {
    const struct htp_tensor * q     = octx->src[0];
    const struct htp_tensor * k     = octx->src[1];
    const struct htp_tensor * v     = octx->src[2];
    const struct htp_tensor * g     = octx->src[3];
    const struct htp_tensor * beta  = octx->src[4];
    const struct htp_tensor * state = octx->src[5];
    const struct htp_tensor * dst   = octx->dst;

    if (!q || !k || !v || !g || !beta || !state || !dst) {
        return HTP_STATUS_INVAL_PARAMS;
    }

    if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 ||
        g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 ||
        dst->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    const uint32_t S_v      = v->ne[0];
    const uint32_t H        = v->ne[1];
    const uint32_t n_tokens = v->ne[2];
    const uint32_t n_seqs   = v->ne[3];

    if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 ||
        q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 ||
        (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
        return HTP_STATUS_NO_SUPPORT;
    }
    if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    struct htp_gdn_context gctx;
    gctx.octx            = octx;
    gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads;
    gctx.state_bytes     = (size_t) S_v * S_v * sizeof(float);

    size_t state_aligned = (size_t) S_v * S_v * sizeof(float);
    state_aligned = (state_aligned + 127) & ~(size_t) 127; // round up to the 128-byte HVX vector size

    gctx.use_vtcm              = false;
    gctx.vtcm_state_base       = NULL;
    gctx.vtcm_state_per_thread = 0;

    // For single-token (generation) calls, stage each thread's state in VTCM if it fits.
    if (n_tokens == 1 && octx->ctx->vtcm_base) {
        size_t vtcm_total = state_aligned * octx->n_threads;
        if (octx->ctx->vtcm_size >= vtcm_total) {
            gctx.use_vtcm              = true;
            gctx.vtcm_state_base       = octx->ctx->vtcm_base;
            gctx.vtcm_state_per_thread = state_aligned;
        }
    }

    if (n_tokens == 1) {
        worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads);
    } else {
        worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads);
    }

    return HTP_STATUS_OK;
}
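
A quick worked example of the VTCM sizing above, with hypothetical values S_v = 128 and 4 worker threads (illustration only):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t S_v = 128, n_threads = 4;            // hypothetical inputs
    size_t state_bytes   = (size_t) S_v * S_v * sizeof(float);  // 65536 bytes per state
    size_t state_aligned = (state_bytes + 127) & ~(size_t) 127; // 65536 (already aligned)
    size_t vtcm_total    = state_aligned * n_threads;           // 262144 bytes needed
    // VTCM staging is used only if vtcm_size >= vtcm_total and n_tokens == 1.
    printf("%zu %zu %zu\n", state_bytes, state_aligned, vtcm_total);
    return 0;
}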

File diff suppressed because it is too large

@@ -1,16 +0,0 @@
#ifndef VTCM_UTILS_H
#define VTCM_UTILS_H

#include "hex-utils.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <hexagon_types.h>

// Bump allocator over a VTCM region: return the current cursor and advance it by `size`.
static inline uint8_t * vtcm_seq_alloc(uint8_t ** vtcm_ptr, size_t size) {
    uint8_t * p = *vtcm_ptr;
    *vtcm_ptr += size;
    return p;
}

#endif // VTCM_UTILS_H
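
Typical use is carving one VTCM window into per-purpose scratch slices; a minimal sketch, with hypothetical names (vtcm_base, vtcm_size) and sizes:

// Hypothetical usage: slice one VTCM window into two scratch areas.
uint8_t * cur    = vtcm_base;                          // start of the reserved window
uint8_t * spad_a = vtcm_seq_alloc(&cur, 64 * 1024);    // 64 KiB scratch for operand A
uint8_t * spad_b = vtcm_seq_alloc(&cur, 32 * 1024);    // 32 KiB scratch for operand B
assert((size_t) (cur - vtcm_base) <= vtcm_size);       // stay inside the window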

@@ -1,302 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load : enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable

#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32

// Convert 8 packed mxfp4 (e2m1) values to 8 halves. The fp4 exponent/mantissa
// bits are shifted into fp16 position, an exponent bias of 0x3800 (0.5f) is
// added to every nonzero code, and the subnormal pattern 0x0200 (e2m1 code 1,
// value 0.5) is zeroed first so the bias alone produces it.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}
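
The bit-twiddling above is equivalent to a lookup over the 16 possible e2m1 codes (magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6 plus a sign bit); a plain-C reference sketch, illustration only:

// Scalar reference: decode one 4-bit e2m1 (mxfp4) code to float.
static inline float mxfp4_to_f32_ref(unsigned char code) {
    const float lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
    const float mag = lut[code & 0x7];
    return (code & 0x8) ? -mag : mag;   // top bit of the nibble is the sign
}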

// One 16-wide dot-product tile: multiply a_reg (16 halves of A) against half4
// rows of B staged in local memory, reducing 4 elements at a time in fp16 and
// accumulating the partial sums into the fp32 c_reg for precision.
#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset)      \
    acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]);        \
    acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]);        \
    acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]);        \
    acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]);        \
    acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]);        \
    acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]);        \
    acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]);        \
    acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]);        \
    acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]);        \
    acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]);        \
    acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]);       \
    acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]);       \
    acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]);       \
    acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]);       \
    acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]);       \
    acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]);       \
    acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]);      \
    acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]);      \
    acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]);      \
    acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]);      \
    acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]);      \
    acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]);      \
    acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]);      \
    acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]);      \
    acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]);      \
    acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]);      \
    acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]);      \
    acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]);      \
    acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]);      \
    acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]);      \
    acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]);      \
    acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]);      \
    c_reg.lo += convert_float8(acc.lo);                    \
    c_reg.hi += convert_float8(acc.hi);                    \
    acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]);       \
    acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]);       \
    acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]);       \
    acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]);       \
    acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]);       \
    acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]);       \
    acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]);       \
    acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]);       \
    acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]);       \
    acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]);       \
    acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]);       \
    acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]);       \
    acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]);       \
    acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]);       \
    acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]);       \
    acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]);       \
    acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]);      \
    acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]);      \
    acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]);      \
    acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]);      \
    acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]);     \
    acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]);     \
    acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]);     \
    acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]);     \
    acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]);     \
    acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]);     \
    acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]);     \
    acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]);     \
    acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]);     \
    acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]);     \
    acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]);     \
    acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]);     \
    c_reg.lo += convert_float8(acc.lo);                    \
    c_reg.hi += convert_float8(acc.hi);
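
The reason for reducing only four elements at a time: an fp16 accumulator loses precision (and can overflow) quickly over long sums, so the macro keeps each fp16 dot to 4 terms and carries the running total in fp32. An OpenCL C sketch of the same idea with hypothetical names (illustration only):

// Dot 16 halves in 4-element fp16 chunks, accumulating chunk results in fp32.
inline float dot16_chunked(const half * a, const half * b) {
    float sum = 0.0f;
    for (int c = 0; c < 16; c += 4) {
        half partial = 0.0h;                 // fp16 partial, 4 terms only
        for (int i = 0; i < 4; ++i) {
            partial += a[c + i] * b[c + i];
        }
        sum += (float) partial;              // fp32 carries the long sum
    }
    return sum;
}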

// Decode an e8m0 shared exponent (bias 127) to fp16; codes outside the fp16
// exponent range collapse to 0x7C00 (inf).
static inline half e8m0_to_fp16(uchar x) {
    ushort bits;
    bits = (ushort)(x) - (ushort)(112);
    bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10);
    return as_half(bits);
}

// Decode an e8m0 shared exponent to fp32 by placing x in the fp32 exponent
// field, i.e. 2^(x - 127); x == 0 maps to the subnormal 2^-127.
static inline float e8m0_to_fp32(uchar x) {
    int bits;
    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}
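
Since e8m0 is just a biased power-of-two scale, the decode can be cross-checked against ldexpf; a plain-C reference sketch (illustration only):

#include <math.h>

// An e8m0 code x encodes the scale 2^(x - 127), so e8m0_to_fp32(x) should
// agree with this. Examples: x = 127 -> 1.0f, x = 128 -> 2.0f, x = 0 -> 2^-127.
static float e8m0_to_f32_ref(unsigned char x) {
    return ldexpf(1.0f, (int) x - 127);
}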
|
||||
|
||||
|
||||
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_mxfp4_f32_ns(
        __read_only  image1d_buffer_t   src0_q,
        __global     uchar            * src0_d,
        __read_only  image1d_buffer_t   src1,
        __global     uint             * src2,
        __global     ushort           * src2_emap,
        __write_only image1d_buffer_t   dst,
        __global     int              * total_tiles,
        uint ne00,
        uint ne01
) {
    uint block_id_m = get_global_id(1); // m_tile
    uint block_id_n = get_global_id(2); // n_tile

    // Boundary check
    if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
        return;
    }

    __private half16  reg_a;
    __private float32 reg_c = (float32)(0);
    __local   half4   shared_b[128];

    const ushort expert_id = src2_emap[block_id_n];

    const uint row = block_id_m * TILESIZE_M;
    const uint col = block_id_n * TILESIZE_N;

    uint sub_block_id_m = get_local_id(0);
    uint2 b_global_offset;
    b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
    b_global_offset.y = b_global_offset.x + (16 * ne00);
    uint2 b_local_offset;
    b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
    b_local_offset.y = b_local_offset.x + 16;

    // Loop along the K axis, 32 elements (one block) per iteration, split into 2 sub-blocks
    for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
        // First sub-block
        uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
        uint b_sub_offset = col * ne00 + step;

        // Load scale for the current mxfp4 block
        uint s_offset = s_sub_offset + get_global_id(0);
        float s = e8m0_to_fp32(src0_d[s_offset]);

        // Load 16 fp4 (64 bits) in transposed layout
        uint2 mxfp4x16;
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        float8 bx8_f32;
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        half8 bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot products with 8-element reduction for better precision
        half16 acc;
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);

        // Repeat for the second sub-block
        uint half_step = step + TILESIZE_K;
        q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        b_sub_offset = col * ne00 + half_step;

        // Load the next 16 fp4 (64 bits) in transposed layout
        mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s;
        reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s;

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot products with 3-level reduction for better precision
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
    }

    // Load post-router output indices and share them in LM
    __local uint out_idx[TILESIZE_N];

    if (get_local_id(0) < TILESIZE_N) {
        uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
        if (idx == 0xFFFFFFFF) {
            idx = src2[block_id_n * TILESIZE_N + 0];
        }
        out_idx[get_local_id(0)] = idx * ne01;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Scatter results back to their original positions in the output grid
    uint m_offset = row + get_local_id(0);

    write_imagef(dst, out_idx[1]  + m_offset, (reg_c.s1));
    write_imagef(dst, out_idx[2]  + m_offset, (reg_c.s2));
    write_imagef(dst, out_idx[3]  + m_offset, (reg_c.s3));
    write_imagef(dst, out_idx[4]  + m_offset, (reg_c.s4));
    write_imagef(dst, out_idx[5]  + m_offset, (reg_c.s5));
    write_imagef(dst, out_idx[6]  + m_offset, (reg_c.s6));
    write_imagef(dst, out_idx[7]  + m_offset, (reg_c.s7));
    write_imagef(dst, out_idx[8]  + m_offset, (reg_c.s8));
    write_imagef(dst, out_idx[9]  + m_offset, (reg_c.s9));
    write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
    write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
    write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
    write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
    write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
    write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
    write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
    write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
    write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
    write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
    write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
    write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
    write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
    write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
    write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
    write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
    write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
    write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
    write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
    write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
    write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
    write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));

    // Zero-padding lanes alias the index of the first output in the tile; the
    // correct result for that index is written last so it overrides the padding writes
    barrier(CLK_GLOBAL_MEM_FENCE);
    write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

@ -1,252 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable

#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32

#define dequantize_q4_0(q4, a_f16, scale) \
    a_f16.s0 = (half)((q4.s0 & 0x000F) - 8) * scale;          \
    a_f16.s1 = (half)(((q4.s0 & 0x00F0) >> 4) - 8) * scale;   \
    a_f16.s2 = (half)(((q4.s0 & 0x0F00) >> 8) - 8) * scale;   \
    a_f16.s3 = (half)(((q4.s0 & 0xF000) >> 12) - 8) * scale;  \
    a_f16.s4 = (half)((q4.s1 & 0x000F) - 8) * scale;          \
    a_f16.s5 = (half)(((q4.s1 & 0x00F0) >> 4) - 8) * scale;   \
    a_f16.s6 = (half)(((q4.s1 & 0x0F00) >> 8) - 8) * scale;   \
    a_f16.s7 = (half)(((q4.s1 & 0xF000) >> 12) - 8) * scale;  \
    a_f16.s8 = (half)((q4.s2 & 0x000F) - 8) * scale;          \
    a_f16.s9 = (half)(((q4.s2 & 0x00F0) >> 4) - 8) * scale;   \
    a_f16.sa = (half)(((q4.s2 & 0x0F00) >> 8) - 8) * scale;   \
    a_f16.sb = (half)(((q4.s2 & 0xF000) >> 12) - 8) * scale;  \
    a_f16.sc = (half)((q4.s3 & 0x000F) - 8) * scale;          \
    a_f16.sd = (half)(((q4.s3 & 0x00F0) >> 4) - 8) * scale;   \
    a_f16.se = (half)(((q4.s3 & 0x0F00) >> 8) - 8) * scale;   \
    a_f16.sf = (half)(((q4.s3 & 0xF000) >> 12) - 8) * scale;  \

#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \
    acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]);  \
    acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]);  \
    acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]);  \
    acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]);  \
    acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]);  \
    acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]);  \
    acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]);  \
    acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]);  \
    acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]);  \
    acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]);  \
    acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \
    acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \
    acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \
    acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \
    acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \
    acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \
    acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \
    acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \
    acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \
    acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \
    acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \
    acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \
    acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \
    acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \
    acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \
    acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \
    acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \
    acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \
    acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \
    acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \
    acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \
    acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi); \
    acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \
    acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \
    acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \
    acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \
    acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \
    acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \
    acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \
    acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \
    acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \
    acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \
    acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \
    acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \
    acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \
    acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \
    acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \
    acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \
    acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]);  \
    acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]);  \
    acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]);  \
    acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]);  \
    acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \
    acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \
    acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \
    acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \
    acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \
    acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \
    acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \
    acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \
    acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \
    acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \
    acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \
    acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \
    c_reg.lo += convert_float8(acc.lo); \
    c_reg.hi += convert_float8(acc.hi); \

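// dotx16_reduce8 (above) computes one 16x16 tile of A*B as four groups of
// half4 dot products. Each acc.sX accumulates two half4 dots (8 K-elements)
// entirely in fp16 and is then flushed into the fp32 accumulator c_reg via
// convert_float8 before the next 8 K-elements start. Keeping at most 8
// products in fp16 before promoting to fp32 bounds the rounding error of the
// low-precision dot path while still exploiting the fast fp16 dot() unit.
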
__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair
kernel void kernel_gemm_moe_q4_0_f32_ns(
        __read_only  image1d_buffer_t   src0_q,
        __global     half             * src0_d,
        __read_only  image1d_buffer_t   src1,
        __global     uint             * src2,
        __global     ushort           * src2_emap,
        __write_only image1d_buffer_t   dst,
        __global     int              * total_tiles,
        uint ne00,
        uint ne01
) {
    uint block_id_m = get_global_id(1); // m_tile
    uint block_id_n = get_global_id(2); // n_tile

    // Boundary check
    if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
        return;
    }

    __private half16  reg_a;
    __private float32 reg_c = (float32)(0);
    __local   half4   shared_b[128];

    const ushort expert_id = src2_emap[block_id_n];

    const uint row = block_id_m * TILESIZE_M;
    const uint col = block_id_n * TILESIZE_N;

    uint sub_block_id_m = get_local_id(0);
    uint2 b_global_offset;
    b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00;
    b_global_offset.y = b_global_offset.x + (16 * ne00);
    uint2 b_local_offset;
    b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2);
    b_local_offset.y = b_local_offset.x + 16;

    // Loop along the K axis, 32 elements (one block) per iteration, split into 2 sub-blocks
    for (uint step = 0; step < ne00; step += TILESIZE_K * 2) {
        // First sub-block
        uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5);
        uint b_sub_offset = col * ne00 + step;

        // Load scale for the current Q4_0 block
        uint s_offset = s_sub_offset + get_global_id(0);
        half s = src0_d[s_offset];

        // Load 16 q (64 bits) in transposed layout
        uint2 q4x16;
        q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        float8 bx8_f32;
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        half8 bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        dequantize_q4_0(as_ushort4(q4x16), reg_a, s);

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot products with 8-element reduction for better precision
        half16 acc;
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);

        // Repeat for the second sub-block
        uint half_step = step + TILESIZE_K;
        q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3);
        b_sub_offset = col * ne00 + half_step;

        // Load the next 16 q (64 bits) in transposed layout
        q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x;
        q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x;

        // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements
        bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4);
        bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4);
        // Convert to half and store to LM to share within the subgroup
        bx8_f16 = convert_half8(bx8_f32);
        shared_b[b_local_offset.x] = bx8_f16.lo;
        shared_b[b_local_offset.y] = bx8_f16.hi;

        // Dequantization
        dequantize_q4_0(as_ushort4(q4x16), reg_a, s);

        sub_group_barrier(CLK_LOCAL_MEM_FENCE);

        // 32 16x16 fp16 dot products with 3-level reduction for better precision
        dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0);
        dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
    }

    // Load post-router output indices and share them in LM
    __local uint out_idx[TILESIZE_N];

    if (get_local_id(0) < TILESIZE_N) {
        uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)];
        if (idx == 0xFFFFFFFF) {
            idx = src2[block_id_n * TILESIZE_N + 0];
        }
        out_idx[get_local_id(0)] = idx * ne01;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Scatter results back to their original positions in the output grid
    uint m_offset = row + get_local_id(0);

    write_imagef(dst, out_idx[1]  + m_offset, (reg_c.s1));
    write_imagef(dst, out_idx[2]  + m_offset, (reg_c.s2));
    write_imagef(dst, out_idx[3]  + m_offset, (reg_c.s3));
    write_imagef(dst, out_idx[4]  + m_offset, (reg_c.s4));
    write_imagef(dst, out_idx[5]  + m_offset, (reg_c.s5));
    write_imagef(dst, out_idx[6]  + m_offset, (reg_c.s6));
    write_imagef(dst, out_idx[7]  + m_offset, (reg_c.s7));
    write_imagef(dst, out_idx[8]  + m_offset, (reg_c.s8));
    write_imagef(dst, out_idx[9]  + m_offset, (reg_c.s9));
    write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa));
    write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb));
    write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc));
    write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd));
    write_imagef(dst, out_idx[14] + m_offset, (reg_c.se));
    write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf));
    write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg));
    write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh));
    write_imagef(dst, out_idx[18] + m_offset, (reg_c.si));
    write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj));
    write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk));
    write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl));
    write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm));
    write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn));
    write_imagef(dst, out_idx[24] + m_offset, (reg_c.so));
    write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp));
    write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq));
    write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr));
    write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss));
    write_imagef(dst, out_idx[29] + m_offset, (reg_c.st));
    write_imagef(dst, out_idx[30] + m_offset, (reg_c.su));
    write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv));

    // Zero-padding lanes alias the index of the first output in the tile; the
    // correct result for that index is written last so it overrides the padding writes
    barrier(CLK_GLOBAL_MEM_FENCE);
    write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0));
}

@ -1,161 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#define QK_MXFP4 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64

static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;

    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
    sign_b.hi = fp4x8.s0 & 0x8000;

    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;

    ushort2 fp16_packed_a_1, fp16_packed_b_1;
    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;

    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;

    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;

    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
    sign_b.hi = fp4x8.s1 & 0x8000;

    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;

    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
}

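// Note on mxfp4_to_fp16_packed8 (above): each fp4 value is E2M1 (1 sign,
// 2 exponent, 1 mantissa bits). Shifting a nibble so its mantissa lands at
// fp16 bit 9 and its exponent at bits 10-11, then adding the bias constant
// 0x3800 (14 << 10, which rebases the 2-bit exponent from bias 1 to fp16's
// bias 15), rebuilds the fp16 encoding without a lookup table. The two
// special cases are handled branchlessly: an all-zero magnitude keeps bias 0
// so the result is +/-0, and the shifted pattern 0x0200 (exponent 0,
// mantissa 1, i.e. the fp4 subnormal 0.5) is cleared so only the bias
// remains, giving exactly 0.5 in fp16. The sign bit is shifted to fp16
// bit 15 and added back at the end.
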
static inline float e8m0_to_fp32(uchar x) {
    int bits;
    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
    return as_float(bits);
}

__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_mxfp4_f32_ns(
        __global uint  * src0_q,
        __global uchar * src0_e,
        __read_only image1d_buffer_t src1,
        __global uint  * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    uint i01  = get_global_id(0);
    uint i20  = get_global_id(2);
    uint sgid = get_local_id(1);
    uint slid = get_sub_group_local_id();

    uint i11 = i20 % ne11;

    uint expert_id     = src2[i20];
    uint expert_offset = expert_id * ne00 * ne01 / 32;

    __private float sum = 0.0f; // each thread calculates the partial sum of one output

    // loop along ne00 at block granularity, advancing N_SIMDGROUP (4) blocks per iteration
    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {

        // load one block of q
        uint4 regQ;
        uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;

        regQ.s0 = src0_q[block_offset];
        regQ.s1 = src0_q[block_offset + ne01];
        regQ.s2 = src0_q[block_offset + ne01 * 2];
        regQ.s3 = src0_q[block_offset + ne01 * 3];

        uint offset = i11 * ne00 / 4 + ib00 * 8;

        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * convert_float4(fp16x8.lo);

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * convert_float4(fp16x8.hi);

        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * convert_float4(fp16x8.lo);

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * convert_float4(fp16x8.hi);

        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * convert_float4(fp16x8.lo);

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * convert_float4(fp16x8.hi);

        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * convert_float4(fp16x8.lo);

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * convert_float4(fp16x8.hi);

        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory, assumes #subgroups == 4
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 output per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2);
        dst[i01 + i20 * ne01] = sum;
    }
}

@ -1,116 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#define QK_Q4_0 32
#define N_SIMDGROUP 4
#define SIMDGROUP_WIDTH 64

static inline float8 q4_0_to_fp32_packed8(ushort2 q4x8) {
    float8 fp32x8;
    fp32x8.s0 = (float)((q4x8.s0 & 0x000F) - 8);
    fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) - 8);
    fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) - 8);
    fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) - 8);
    fp32x8.s4 = (float)((q4x8.s1 & 0x000F) - 8);
    fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) - 8);
    fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) - 8);
    fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) - 8);
    return fp32x8;
}

__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemv_moe_q4_0_f32_ns(
        __global uint  * src0_q,
        __global half  * src0_d,
        __read_only image1d_buffer_t src1,
        __global uint  * src2,
        __global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne11
) {
    uint i01  = get_global_id(0);
    uint i20  = get_global_id(2);
    uint sgid = get_local_id(1);
    uint slid = get_sub_group_local_id();

    uint i11 = i20 % ne11;

    uint expert_id     = src2[i20];
    uint expert_offset = expert_id * ne00 * ne01 / 32;

    __private float sum = 0.0f; // each thread calculates the partial sum of one output

    // loop along ne00 at block granularity, advancing N_SIMDGROUP (4) blocks per iteration
    for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_0); ib00 += N_SIMDGROUP) {

        // load one block of q
        uint4 regQ;
        uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01;

        regQ.s0 = src0_q[block_offset];
        regQ.s1 = src0_q[block_offset + ne01];
        regQ.s2 = src0_q[block_offset + ne01 * 2];
        regQ.s3 = src0_q[block_offset + ne01 * 3];

        uint offset = i11 * ne00 / 4 + ib00 * 8;

        float8 fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s0));

        float4 shared_y4;
        shared_y4 = read_imagef(src1, (offset + 0));
        float4 acc = shared_y4 * fp32x8.lo;

        shared_y4 = read_imagef(src1, (offset + 1));
        acc += shared_y4 * fp32x8.hi;

        fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s1));

        shared_y4 = read_imagef(src1, (offset + 2));
        acc += shared_y4 * fp32x8.lo;

        shared_y4 = read_imagef(src1, (offset + 3));
        acc += shared_y4 * fp32x8.hi;

        fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s2));

        shared_y4 = read_imagef(src1, (offset + 4));
        acc += shared_y4 * fp32x8.lo;

        shared_y4 = read_imagef(src1, (offset + 5));
        acc += shared_y4 * fp32x8.hi;

        fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s3));

        shared_y4 = read_imagef(src1, (offset + 6));
        acc += shared_y4 * fp32x8.lo;

        shared_y4 = read_imagef(src1, (offset + 7));
        acc += shared_y4 * fp32x8.hi;

        half regS = src0_d[ib00 * ne01 + i01 + expert_offset];
        sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
    }

    // reduction in local memory, assumes #subgroups == 4
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 output per thread in subgroup 0
    if (sgid == 0) {
        dst = dst + (offsetd >> 2);
        dst[i01 + i20 * ne01] = sum;
    }
}

@ -1,30 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#define QK4_0 32

kernel void kernel_moe_reorder_b(
        global float4 * src,
        global uint   * router,
        global float4 * dst,
        global int    * total_tiles,
        uint   K,
        ushort map_ratio,
        uint   tile_size
) {
    uint k_4             = get_global_id(0);
    uint post_router_idx = get_global_id(1);

    if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) {
        return;
    }

    uint router_idx = router[post_router_idx];

    float4 out = (float4)(0);
    if (router_idx != 0xFFFFFFFF) {
        ushort activation_idx = router_idx / map_ratio;
        out = src[activation_idx * K / 4 + k_4];
    }

    dst[post_router_idx * K / 4 + k_4] = out;
}

@ -1,82 +0,0 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

__kernel void kernel_moe_histogram(
        __global const int * input,
        __global int * hist,
        uint N,
        uint topK,
        uint n_experts
) {
    uint n = get_global_id(0);
    uint k = get_global_id(1);

    if (n >= N || k >= topK) {
        return;
    }

    int expert_id = input[n * n_experts + k];
    atomic_inc(&hist[expert_id]);
}

__kernel void kernel_moe_scan(
        __global int * hist,
        __global int * tile_offset,
        __global int * total_tiles,
        __global int * slot_counter,
        int  tile_size,
        uint n_experts
) {
    int offset = 0;
    for (int v = 0; v < n_experts; v++) {
        int count = hist[v];
        int tiles = (count + tile_size - 1) / tile_size;
        tile_offset[v] = offset;
        offset += tiles;
        hist[v] = 0;
        slot_counter[v] = 0;
    }

    *total_tiles = offset;
}

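// kernel_moe_scatter and kernel_moe_fill below complete the tile dispatch
// tables consumed by the gemm_moe_*_ns kernels: histogram counts how many
// (token, expert) assignments each expert receives, scan turns those counts
// into per-expert tile offsets (tiles of tile_size rows) and resets the
// counters, scatter claims a slot per assignment via atomic_inc and records
// the originating row (n * topK + k) plus the tile's expert id in emap, and
// fill marks every slot with 0xFFFFFFFF so partially filled tiles read as
// padding. The host presumably enqueues fill before scatter, so the padding
// marks survive only in unclaimed slots.
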
__kernel void kernel_moe_scatter(
        __global const int * input,
        __global int       * post_router,
        __global ushort    * emap,
        __global const int * tile_offset,
        __global int       * slot_counter,
        int  N,
        int  topK,
        uint n_experts
) {
    uint n = get_global_id(0);
    uint k = get_global_id(1);

    if (n >= N || k >= topK) {
        return;
    }

    int val = input[n * n_experts + k];

    int local_slot = atomic_inc(&slot_counter[val]);

    int tile_idx = tile_offset[val] + (local_slot / 32);
    int lane     = local_slot % 32;
    int out_pos  = tile_idx * 32 + lane;

    post_router[out_pos] = n * topK + k;
    emap[tile_idx] = val;
}

__kernel void kernel_moe_fill(
        __global int * post_router,
        __global int * total_tiles,
        int tile_size
) {
    int tile_id        = get_global_id(0);
    int vec_id_in_tile = get_global_id(1);

    if (tile_id < total_tiles[0]) {
        post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF;
    }
}

1974	ggml/src/ggml-rpc/ggml-rpc.cpp	Normal file
File diff suppressed because it is too large
683	ggml/src/ggml-rpc/transport.cpp	Normal file

@ -0,0 +1,683 @@
#include "transport.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winsock2.h>
|
||||
#else
|
||||
# include <arpa/inet.h>
|
||||
# include <sys/socket.h>
|
||||
# include <sys/types.h>
|
||||
# include <netinet/in.h>
|
||||
# include <netinet/tcp.h>
|
||||
# include <netdb.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
|
||||
#ifdef GGML_RPC_RDMA
|
||||
# include <infiniband/verbs.h>
|
||||
# include <time.h>
|
||||
# ifndef _WIN32
|
||||
# include <poll.h>
|
||||
# endif
|
||||
#endif // GGML_RPC_RDMA
|
||||
|
||||
#ifdef _WIN32
|
||||
typedef SOCKET sockfd_t;
|
||||
using ssize_t = __int64;
|
||||
#else
|
||||
typedef int sockfd_t;
|
||||
#endif
|
||||
|
||||
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
|
||||
|
||||
#define LOG_DBG(...) \
|
||||
do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)
|
||||
|
||||
#ifdef GGML_RPC_RDMA
|
||||
static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
|
||||
static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB
|
||||
static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes
|
||||
using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>;
|
||||
|
||||
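
// Transport model (a sketch of the scheme implemented below): the RPC byte
// stream is carried over a reliable-connected (RC) queue pair using two-sided
// SEND/RECV. Payloads are cut into RDMA_CHUNK-sized pieces that bounce
// through a single registered tx buffer, while the receiver keeps
// RDMA_RX_DEPTH receives pre-posted on a contiguous ring so a sender never
// finds the receive queue empty. The companion TCP socket stays open purely
// for the capability handshake and for detecting peer shutdown while
// spinning on a completion queue.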

struct rdma_conn {
    struct ibv_context * ctx = nullptr;
    struct ibv_pd      * pd  = nullptr;
    struct ibv_cq      * scq = nullptr; // send completions
    struct ibv_cq      * rcq = nullptr; // recv completions
    struct ibv_qp      * qp  = nullptr;

    void          * tx_buf = nullptr;
    struct ibv_mr * tx_mr  = nullptr;

    void          * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
    struct ibv_mr * rx_mr  = nullptr;
    int             rx_head = 0;

    uint32_t max_inline = 0;

    uint8_t * rx_slot(int i) const {
        return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK;
    }

    bool post_rx(int i) {
        struct ibv_sge sge = {};
        sge.addr   = (uintptr_t)rx_slot(i);
        sge.length = RDMA_CHUNK;
        sge.lkey   = rx_mr->lkey;
        struct ibv_recv_wr wr = {}, * bad = nullptr;
        wr.wr_id   = (uint64_t)i;
        wr.sg_list = &sge;
        wr.num_sge = 1;
        return ibv_post_recv(qp, &wr, &bad) == 0;
    }

    ~rdma_conn() {
        if (tx_mr) ibv_dereg_mr(tx_mr);
        if (rx_mr) ibv_dereg_mr(rx_mr);
        free(tx_buf);
        free(rx_buf);
        if (qp)  ibv_destroy_qp(qp);
        if (scq) ibv_destroy_cq(scq);
        if (rcq) ibv_destroy_cq(rcq);
        if (pd)  ibv_dealloc_pd(pd);
        if (ctx) ibv_close_device(ctx);
    }
};

// Local RDMA parameters captured during the probe phase and later consumed
// by rdma_activate() after the remote side's caps arrive via HELLO.
struct rdma_local_info {
    uint32_t qpn = 0;
    uint32_t psn = 0;
    uint8_t  gid[RDMA_GID_SIZE] = {};
    uint8_t  ib_port = 0;
    int      gid_idx = 0;
    enum ibv_mtu path_mtu = IBV_MTU_1024;
};

struct rdma_caps {
    uint32_t qpn;
    uint32_t psn;
    uint8_t  gid[RDMA_GID_SIZE];
};

static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size");

#endif // GGML_RPC_RDMA

struct socket_t::impl {
    impl(sockfd_t fd) : use_rdma(false), fd(fd) {}
    ~impl();
    bool send_data(const void * data, size_t size);
    bool recv_data(void * data, size_t size);
    void get_caps(uint8_t * local_caps);
    void update_caps(const uint8_t * remote_caps);

#ifdef GGML_RPC_RDMA
    bool tcp_peer_closed();
    std::optional<rdma_gid_t> rdma_build_target_gid();
    bool rdma_probe();
    bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid);
    bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc);
    bool rdma_send(const void * data, size_t size);
    bool rdma_recv(void * data, size_t size);

    std::unique_ptr<rdma_conn> rdma;
    rdma_local_info rdma_local = {};
#endif // GGML_RPC_RDMA
    bool use_rdma;
    sockfd_t fd;
};

socket_t::impl::~impl() {
#ifdef GGML_RPC_RDMA
    rdma.reset();
#endif // GGML_RPC_RDMA
    LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
    if (fd != INVALID_SOCKET) closesocket(this->fd);
#else
    if (fd >= 0) close(this->fd);
#endif
}

#ifdef GGML_RPC_RDMA

bool socket_t::impl::tcp_peer_closed() {
    if (fd < 0) return false;
#ifndef _WIN32
    struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 };
    int r = poll(&pfd, 1, 0);
    return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP));
#else
    return false;
#endif
}

// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address.
// Used to match the socket's local IP against the kernel's GID table so that
// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly:
//   AF_INET                -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4)
//   AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape)
//   AF_INET6 (native v6)   -> the 16-byte IPv6 address as-is
// Returns std::nullopt on unsupported family or getsockname failure.
std::optional<rdma_gid_t> socket_t::impl::rdma_build_target_gid() {
    sockaddr_storage addr = {};
    socklen_t addr_len = sizeof(addr);
    if (getsockname(fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) {
        return std::nullopt;
    }
    rdma_gid_t target = {};
    if (addr.ss_family == AF_INET) {
        const auto * a = reinterpret_cast<const sockaddr_in *>(&addr);
        target[10] = 0xff;
        target[11] = 0xff;
        memcpy(&target[12], &a->sin_addr, 4);
        return target;
    }
    if (addr.ss_family == AF_INET6) {
        const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr);
        memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE);
        return target;
    }
    return std::nullopt;
}

bool socket_t::impl::rdma_probe() {
    const char * dev_env = std::getenv("GGML_RDMA_DEV");
    const char * gid_env = std::getenv("GGML_RDMA_GID");

    auto target_gid = rdma_build_target_gid();
    if (!target_gid) {
        return false;
    }

    const uint8_t ib_port = 1;
    int num_devs = 0;
    ibv_device ** devs = ibv_get_device_list(&num_devs);
    if (!devs || num_devs == 0) return false;

    ibv_context * ibctx = nullptr;
    const char * matched_dev = nullptr;
    int gid_idx     = gid_env ? atoi(gid_env) : -1;
    int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB

    for (int d = 0; d < num_devs; d++) {
        const char * dn = ibv_get_device_name(devs[d]);
        if (dev_env && strcmp(dev_env, dn) != 0) continue;

        ibv_context * ctx = ibv_open_device(devs[d]);
        if (!ctx) continue;

        ibv_port_attr pa;
        if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }

        int found_gid     = gid_idx;
        int found_version = IBV_GID_TYPE_IB;
        if (found_gid < 0) {
            // Find a GID on this port whose bytes equal the local TCP address
            // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1
            // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths
            // are avoided. ibv_query_gid_ex returns gid+type in one call.
            int v2_idx = -1;
            int v1_idx = -1;
            for (int i = 0; i < pa.gid_tbl_len; i++) {
                ibv_gid_entry entry = {};
                if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue;
                if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue;
                if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) {
                    v2_idx = i;
                } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) {
                    v1_idx = i;
                }
            }
            if (v2_idx >= 0) {
                found_gid     = v2_idx;
                found_version = IBV_GID_TYPE_ROCE_V2;
            } else if (v1_idx >= 0) {
                found_gid     = v1_idx;
                found_version = IBV_GID_TYPE_ROCE_V1;
            }
        } else {
            // Explicit GID index from GGML_RDMA_GID — fetch its type for logging.
            ibv_gid_entry entry = {};
            if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) {
                found_version = entry.gid_type;
            }
        }
        if (found_gid >= 0) {
            ibctx       = ctx;
            gid_idx     = found_gid;
            gid_version = found_version;
            matched_dev = dn;
            rdma_local.path_mtu = pa.active_mtu;
            break;
        }
        ibv_close_device(ctx);
    }
    ibv_free_device_list(devs);
    if (!ibctx) return false;

    rdma_local.ib_port = ib_port;
    rdma_local.gid_idx = gid_idx;

    rdma = std::make_unique<rdma_conn>();
    rdma->ctx = ibctx;

    rdma->pd = ibv_alloc_pd(ibctx);
    if (!rdma->pd) return false;

    rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
    rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
    if (!rdma->scq || !rdma->rcq) return false;

    ibv_qp_init_attr qia = {};
    qia.send_cq = rdma->scq;
    qia.recv_cq = rdma->rcq;
    qia.qp_type = IBV_QPT_RC;
    qia.cap.max_send_wr     = 4;
    qia.cap.max_recv_wr     = RDMA_RX_DEPTH + 4;
    qia.cap.max_send_sge    = 1;
    qia.cap.max_recv_sge    = 1;
    qia.cap.max_inline_data = 256;

    rdma->qp = ibv_create_qp(rdma->pd, &qia);
    if (!rdma->qp) return false;
    rdma->max_inline = qia.cap.max_inline_data;

    rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
    rdma->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK);
    if (!rdma->tx_buf || !rdma->rx_buf) return false;

    rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
    rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK,
                             IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    if (!rdma->tx_mr || !rdma->rx_mr) return false;

    ibv_gid local_gid;
    if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false;

    rdma_local.qpn = rdma->qp->qp_num;
    rdma_local.psn = rdma->qp->qp_num & 0xffffff;
    memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE);

    const char * ver_str = "";
    if (gid_version == IBV_GID_TYPE_ROCE_V2) {
        ver_str = " RoCEv2";
    } else if (gid_version == IBV_GID_TYPE_ROCE_V1) {
        ver_str = " RoCEv1";
    }
    GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n",
                  matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline);
    return true;
}

// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS.
// On success, the connection is live and ready for rdma_send/rdma_recv.
bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
    // RESET -> INIT
    {
        struct ibv_qp_attr a = {};
        a.qp_state   = IBV_QPS_INIT;
        a.port_num   = rdma_local.ib_port;
        a.pkey_index = 0;
        a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
        if (ibv_modify_qp(rdma->qp, &a,
                          IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
            return false;
        }
    }

    for (int i = 0; i < RDMA_RX_DEPTH; i++) {
        if (!rdma->post_rx(i)) return false;
    }

    // INIT -> RTR
    {
        struct ibv_qp_attr a = {};
        a.qp_state           = IBV_QPS_RTR;
        a.path_mtu           = rdma_local.path_mtu;
        a.dest_qp_num        = remote_qpn;
        a.rq_psn             = remote_psn;
        a.max_dest_rd_atomic = 1;
        a.min_rnr_timer      = 1;
        a.ah_attr.is_global  = 1;
        memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE);
        a.ah_attr.grh.hop_limit  = 1;
        a.ah_attr.grh.sgid_index = rdma_local.gid_idx;
        a.ah_attr.dlid           = 0;
        a.ah_attr.port_num       = rdma_local.ib_port;
        if (ibv_modify_qp(rdma->qp, &a,
                          IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
                          IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
            return false;
        }
    }

    // RTR -> RTS
    {
        struct ibv_qp_attr a = {};
        a.qp_state      = IBV_QPS_RTS;
        a.timeout       = 14;
        a.retry_cnt     = 7;
        a.rnr_retry     = 7;
        a.sq_psn        = rdma_local.psn;
        a.max_rd_atomic = 1;
        if (ibv_modify_qp(rdma->qp, &a,
                          IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
                          IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
            return false;
        }
    }

    GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n",
                  rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH);
    return true;
}

bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) {
    for (uint64_t s = 0; ; s++) {
        int n = ibv_poll_cq(cq, 1, wc);
        if (n > 0) {
            if (wc->status != IBV_WC_SUCCESS) {
                GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n",
                               wc->status, ibv_wc_status_str(wc->status), wc->vendor_err);
            }
            return wc->status == IBV_WC_SUCCESS;
        }
        if (n < 0) return false;
        if ((s & 0xFFFFF) == 0 && s > 0) {
            if (tcp_peer_closed()) {
                return false;
            }
        }
    }
}
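
// rdma_send/rdma_recv below move each chunk synchronously: one signaled work
// request is posted and then waited for before the next chunk starts. Chunks
// that fit the QP's max_inline_data are sent inline, letting the HCA copy the
// bytes out of the work request itself so the unregistered source buffer
// needs no staging; larger chunks are memcpy'd through the registered tx
// bounce buffer first.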

bool socket_t::impl::rdma_send(const void * data, size_t size) {
    rdma_conn * c = rdma.get();
    const uint8_t * src = (const uint8_t *)data;
    size_t rem = size;
    while (rem > 0) {
        size_t chunk = std::min(rem, RDMA_CHUNK);

        struct ibv_sge sge = {};
        struct ibv_send_wr wr = {}, * bad = nullptr;
        wr.opcode  = IBV_WR_SEND;
        wr.sg_list = &sge;
        wr.num_sge = 1;

        if (chunk <= c->max_inline) {
            sge.addr      = (uintptr_t)src;
            sge.length    = chunk;
            wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
        } else {
            memcpy(c->tx_buf, src, chunk);
            sge.addr      = (uintptr_t)c->tx_buf;
            sge.length    = chunk;
            sge.lkey      = c->tx_mr->lkey;
            wr.send_flags = IBV_SEND_SIGNALED;
        }

        if (ibv_post_send(c->qp, &wr, &bad) != 0) return false;
        struct ibv_wc wc;
        if (!rdma_poll(c->scq, &wc)) return false;

        src += chunk;
        rem -= chunk;
    }
    return true;
}

bool socket_t::impl::rdma_recv(void * data, size_t size) {
    rdma_conn * c = rdma.get();
    uint8_t * dst = (uint8_t *)data;
    size_t rem = size;
    while (rem > 0) {
        struct ibv_wc wc;
        if (!rdma_poll(c->rcq, &wc)) return false;

        int slot   = (int)wc.wr_id;
        size_t got = wc.byte_len;
        memcpy(dst, c->rx_slot(slot), got);

        if (!c->post_rx(slot)) return false;

        dst += got;
        rem -= got;
    }
    return true;
}

#endif // GGML_RPC_RDMA

bool socket_t::impl::send_data(const void * data, size_t size) {
#ifdef GGML_RPC_RDMA
    if (use_rdma) {
        return rdma_send(data, size);
    }
#endif
    size_t bytes_sent = 0;
    while (bytes_sent < size) {
        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
        ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0);
        if (n < 0) {
            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
                           bytes_sent, size_to_send);
            return false;
        }
        bytes_sent += (size_t)n;
    }
    return true;
}

bool socket_t::impl::recv_data(void * data, size_t size) {
#ifdef GGML_RPC_RDMA
    if (use_rdma) {
        return rdma_recv(data, size);
    }
#endif
    size_t bytes_recv = 0;
    while (bytes_recv < size) {
        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
        ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0);
        if (n < 0) {
            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
                           bytes_recv, size_to_recv);
            return false;
        }
        if (n == 0) {
            LOG_DBG("recv returned 0 (peer closed?)\n");
            return false;
        }
        bytes_recv += (size_t)n;
    }
    return true;
}

void socket_t::impl::get_caps(uint8_t * local_caps) {
    memset(local_caps, 0, RPC_CONN_CAPS_SIZE);
#ifdef GGML_RPC_RDMA
    rdma_local = {};
    if (rdma_probe()) {
        rdma_caps rc = {};
        rc.qpn = rdma_local.qpn;
        rc.psn = rdma_local.psn;
        memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE);
        memcpy(local_caps, &rc, sizeof(rc));
    } else {
        rdma.reset();
    }
#endif // GGML_RPC_RDMA
}

void socket_t::impl::update_caps(const uint8_t * remote_caps) {
#ifdef GGML_RPC_RDMA
    if (!rdma) {
        return;
    }
    rdma_caps rc = {};
    memcpy(&rc, remote_caps, sizeof(rc));
    if (rc.qpn == 0) {
        rdma.reset();
        return;
    }
    if (rdma_activate(rc.qpn, rc.psn, rc.gid)) {
        use_rdma = true;
    } else {
        GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
        rdma.reset();
    }
#else
    (void)remote_caps;
#endif // GGML_RPC_RDMA
}

/////////////////////////////////////////////////////////////////////////////

socket_t::socket_t(std::unique_ptr<impl> p) : pimpl(std::move(p)) {}

socket_t::~socket_t() = default;

bool socket_t::send_data(const void * data, size_t size) {
    return pimpl->send_data(data, size);
}

bool socket_t::recv_data(void * data, size_t size) {
    return pimpl->recv_data(data, size);
}

void socket_t::get_caps(uint8_t * local_caps) {
    return pimpl->get_caps(local_caps);
}

void socket_t::update_caps(const uint8_t * remote_caps) {
    return pimpl->update_caps(remote_caps);
}

static bool is_valid_fd(sockfd_t sockfd) {
#ifdef _WIN32
    return sockfd != INVALID_SOCKET;
#else
    return sockfd >= 0;
#endif
}

static bool set_no_delay(sockfd_t sockfd) {
    int flag = 1;
    // set TCP_NODELAY to disable Nagle's algorithm
    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
    return ret == 0;
}

static bool set_reuse_addr(sockfd_t sockfd) {
    int flag = 1;
    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
    return ret == 0;
}

socket_ptr socket_t::accept() {
    auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL);
    if (!is_valid_fd(client_socket_fd)) {
        return nullptr;
    }
    if (!set_no_delay(client_socket_fd)) {
        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    return socket_ptr(new socket_t(std::make_unique<impl>(client_socket_fd)));
}

socket_ptr socket_t::create_server(const char * host, int port) {
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (!is_valid_fd(sockfd)) {
        return nullptr;
    }
    if (!set_reuse_addr(sockfd)) {
        GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
        return nullptr;
    }
    if (inet_addr(host) == INADDR_NONE) {
        GGML_LOG_ERROR("Invalid host address: %s\n", host);
        return nullptr;
    }
    struct sockaddr_in serv_addr;
    serv_addr.sin_family      = AF_INET;
    serv_addr.sin_addr.s_addr = inet_addr(host);
    serv_addr.sin_port        = htons(port);

    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
        return nullptr;
    }
    if (listen(sockfd, 1) < 0) {
        return nullptr;
    }
    return socket_ptr(new socket_t(std::make_unique<impl>(sockfd)));
}

socket_ptr socket_t::connect(const char * host, int port) {
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (!is_valid_fd(sockfd)) {
        return nullptr;
    }
    if (!set_no_delay(sockfd)) {
        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    struct sockaddr_in addr;
    addr.sin_family = AF_INET;
    addr.sin_port   = htons(port);
    struct hostent * server = gethostbyname(host);
    if (server == NULL) {
        GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
        return nullptr;
    }
    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
    if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        return nullptr;
    }
    return socket_ptr(new socket_t(std::make_unique<impl>(sockfd)));
}

#ifdef _WIN32
static std::mutex g_rpc_transport_mu;
static bool g_rpc_transport_wsa_started = false;
#endif

bool rpc_transport_init() {
#ifdef _WIN32
    std::lock_guard<std::mutex> lock(g_rpc_transport_mu);
    if (g_rpc_transport_wsa_started) {
        return true;
    }
    WSADATA wsaData;
    int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (res != 0) {
        return false;
    }
    g_rpc_transport_wsa_started = true;
    return true;
#else
    return true;
#endif
}

void rpc_transport_shutdown() {
#ifdef _WIN32
    std::lock_guard<std::mutex> lock(g_rpc_transport_mu);
    if (!g_rpc_transport_wsa_started) {
        return;
    }
    WSACleanup();
    g_rpc_transport_wsa_started = false;
#endif
}
34	ggml/src/ggml-rpc/transport.h	Normal file

@ -0,0 +1,34 @@
#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>

struct socket_t;
typedef std::shared_ptr<socket_t> socket_ptr;

static constexpr size_t MAX_CHUNK_SIZE     = 1024ull * 1024ull * 1024ull; // 1 GiB
static constexpr size_t RPC_CONN_CAPS_SIZE = 24;

struct socket_t {
    ~socket_t();

    bool send_data(const void * data, size_t size);
    bool recv_data(void * data, size_t size);

    socket_ptr accept();

    void get_caps(uint8_t * local_caps);
    void update_caps(const uint8_t * remote_caps);

    static socket_ptr create_server(const char * host, int port);
    static socket_ptr connect(const char * host, int port);

  private:
    struct impl;
    explicit socket_t(std::unique_ptr<impl> p);
    std::unique_ptr<impl> pimpl;
};

bool rpc_transport_init();
void rpc_transport_shutdown();
@ -1,148 +0,0 @@
#include "cumsum.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#define SYCL_CUMSUM_BLOCK_SIZE 256
|
||||
|
||||
static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) {
|
||||
return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus<float>());
|
||||
}
|
||||
|
||||
static void cumsum_f32_kernel(
|
||||
const float * __restrict__ src, float * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t d1, const int64_t d2, const int64_t d3,
|
||||
const sycl::nd_item<3> & item, float * smem) {
|
||||
|
||||
const int tid = item.get_local_id(2);
|
||||
const int block_size = item.get_local_range(2);
|
||||
const int lane = tid % WARP_SIZE;
|
||||
const int warp = tid / WARP_SIZE;
|
||||
const int warps_per_block = block_size / WARP_SIZE;
|
||||
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + block_size;
|
||||
float * s_carry = smem + block_size + warps_per_block;
|
||||
|
||||
if (tid == 0) {
|
||||
s_carry[0] = 0.0f;
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const int64_t i3 = item.get_group(0);
|
||||
const int64_t i2 = item.get_group(1);
|
||||
const int64_t i1 = item.get_group(2);
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3;
|
||||
|
||||
constexpr int num_unroll = 4;
|
||||
float temp[num_unroll];
|
||||
|
||||
for (int64_t i = 0; i < ne00; i += num_unroll * block_size) {
|
||||
int64_t idx = i + tid * num_unroll;
|
||||
|
||||
temp[0] = (idx < ne00 ? src_row[idx] : 0.0f);
|
||||
#pragma unroll
|
||||
for (int j = 1; j < num_unroll; j++) {
|
||||
temp[j] = temp[j - 1];
|
||||
if (idx + j < ne00) {
|
||||
temp[j] += src_row[idx + j];
|
||||
}
|
||||
}
|
||||
|
||||
float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f;
|
||||
|
||||
val = warp_prefix_inclusive_sum_f32(val, item);
|
||||
s_vals[tid] = val;
|
||||
|
||||
if (lane == WARP_SIZE - 1) {
|
||||
s_warp_sums[warp] = val;
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (warp == 0) {
|
||||
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
||||
float inc = warp_prefix_inclusive_sum_f32(w, item);
|
||||
if (tid < warps_per_block) {
|
||||
s_warp_sums[tid] = inc - w;
|
||||
}
|
||||
if (tid == warps_per_block - 1) {
|
||||
s_carry[1] = inc;
|
||||
}
|
||||
}
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
float carry = s_carry[0];
|
||||
float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < num_unroll; j++) {
|
||||
if (idx + j < ne00) {
|
||||
dst_row[idx + j] = temp[j] + final_offset;
|
||||
}
|
||||
}
|
||||
|
||||
item.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (tid == 0) {
|
||||
s_carry[0] += s_carry[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src_d = static_cast<const float *>(src0->data);
|
||||
float * dst_d = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const size_t ts = sizeof(float);
|
||||
const int64_t s01 = src0->nb[1] / ts;
|
||||
const int64_t s02 = src0->nb[2] / ts;
|
||||
const int64_t s03 = src0->nb[3] / ts;
|
||||
const int64_t d1 = dst->nb[1] / ts;
|
||||
const int64_t d2 = dst->nb[2] / ts;
|
||||
const int64_t d3 = dst->nb[3] / ts;
|
||||
|
||||
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
|
||||
int block_size = num_warps * WARP_SIZE;
|
||||
block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE);
|
||||
const int warps_per_block = block_size / WARP_SIZE;
|
||||
const int smem_size = block_size + warps_per_block + 2;
|
||||
|
||||
const sycl::range<3> grid(ne03, ne02, ne01);
|
||||
const sycl::range<3> block(1, 1, block_size);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid * block, block),
|
||||
[=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03,
|
||||
s01, s02, s03, d1, d2, d3,
|
||||
item, get_pointer(smem_acc));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_cumsum(ctx, dst);
|
||||
}
|
||||
|
|
@@ -1,5 +0,0 @@
#pragma once

#include "common.hpp"

void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
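For reference, the operation the deleted kernel computes per innermost row is a plain inclusive prefix sum; a scalar C++ sketch (illustrative, not part of the commit) that can serve to validate any replacement path:

// Scalar reference for the row-wise inclusive prefix sum.
#include <cstdint>
#include <cstdio>
#include <vector>

static void cumsum_row_ref(const float * src, float * dst, int64_t n) {
    float acc = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        acc += src[i];   // running total
        dst[i] = acc;    // inclusive: dst[i] = sum(src[0..i])
    }
}

int main() {
    std::vector<float> x = { 1, 2, 3, 4 }, y(4);
    cumsum_row_ref(x.data(), y.data(), 4);
    for (float v : y) {
        printf("%g ", v);   // prints: 1 3 6 10
    }
    printf("\n");
    return 0;
}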
@@ -1,67 +0,0 @@
#include "diag.hpp"
#include "common.hpp"

#define SYCL_DIAG_BLOCK_SIZE 256

template <typename T>
static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src,
                        const int64_t ne0, const int64_t ne1,
                        const int64_t ne2, const int64_t ne3,
                        const int64_t total_elements,
                        const sycl::nd_item<1> & item) {
    const int64_t i = item.get_global_id(0);
    if (i >= total_elements) {
        return;
    }

    const int64_t i0 = i % ne0;
    const int64_t i1 = (i / ne0) % ne1;
    const int64_t i2 = (i / (ne0 * ne1)) % ne2;
    const int64_t i3 = i / (ne0 * ne1 * ne2);

    const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;

    if (i0 == i1) {
        const int64_t batch_idx = i3 * ne2 + i2;
        dst[dst_idx] = src[batch_idx * ne0 + i0];
    } else {
        dst[dst_idx] = T(0);
    }

    (void) ne3;
}

inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(src0->ne[1] == 1);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const void * src0_d = src0->data;
    void *       dst_d  = dst->data;

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];
    const int64_t n_elems = ggml_nelements(dst);
    const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE;

    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    stream->parallel_for(
        sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE),
        [=](sycl::nd_item<1> item) {
            diag_kernel(static_cast<float *>(dst_d),
                        static_cast<const float *>(src0_d),
                        ne0, ne1, ne2, ne3, n_elems, item);
        });
}

void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_diag(ctx, dst);
}
@@ -1,5 +0,0 @@
#pragma once

#include "common.hpp"

void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
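The semantics of the deleted kernel, in scalar C++ for clarity (illustrative, not part of the commit): for each (i2, i3) batch, the length-ne0 source row is scattered onto the diagonal of an ne0 x ne1 destination matrix and everything off-diagonal is zeroed.

// Scalar reference for the diag operation, matching the kernel's indexing.
#include <cstdint>

static void diag_ref(const float * src, float * dst,
                     int64_t ne0, int64_t ne1, int64_t batches) {
    for (int64_t b = 0; b < batches; ++b) {            // b = i3 * ne2 + i2
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                dst[(b * ne1 + i1) * ne0 + i0] =
                    (i0 == i1) ? src[b * ne0 + i0] : 0.0f;
            }
        }
    }
}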
@@ -1,55 +0,0 @@
#include "fill.hpp"
#include "common.hpp"

#define SYCL_FILL_BLOCK_SIZE 256

template <typename T>
static void fill_kernel(T * dst, const int64_t k, const T value,
                        const sycl::nd_item<1> & item) {
    const int64_t i = (int64_t) item.get_global_id(0);
    if (i >= k) {
        return;
    }
    dst[i] = value;
}

inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(ggml_is_contiguous(dst));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    float value;
    memcpy(&value, dst->op_params, sizeof(float));

    const int64_t k = ggml_nelements(dst);
    const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE;
    void * dst_d = dst->data;

    switch (dst->type) {
        case GGML_TYPE_F32:
            stream->parallel_for(
                sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE),
                [=](sycl::nd_item<1> item) {
                    fill_kernel(static_cast<float *>(dst_d), k, value, item);
                });
            break;
        case GGML_TYPE_F16:
            {
                sycl::half h_value = sycl::half(value);
                stream->parallel_for(
                    sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE),
                    [=](sycl::nd_item<1> item) {
                        fill_kernel(static_cast<sycl::half *>(dst_d), k, h_value, item);
                    });
            }
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}

void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
    ggml_sycl_op_fill(ctx, dst);
}
@@ -1,5 +0,0 @@
#pragma once

#include "common.hpp"

void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
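The fill op is the simplest of the removed kernels; its scalar equivalent (illustrative) is a single loop writing the constant read from dst->op_params:

// Scalar reference for fill: write one constant to every element.
#include <cstdint>

static void fill_ref(float * dst, int64_t k, float value) {
    for (int64_t i = 0; i < k; ++i) {
        dst[i] = value;
    }
}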
@@ -1,172 +0,0 @@
#include "solve_tri.hpp"
#include "common.hpp"
#include <oneapi/mkl/blas.hpp>

template <int n_template, int k_template>
static void solve_tri_f32_fast(const float * __restrict__ A,
                               const float * __restrict__ B,
                               float * __restrict__ X,
                               const int64_t ne02, [[maybe_unused]] const int64_t ne03,
                               const int64_t nb02, const int64_t nb03,
                               const int64_t nb12, const int64_t nb13,
                               const int64_t nb2, const int64_t nb3,
                               const int n_arg, const int k_arg,
                               const sycl::nd_item<2> & item, float * sA) {

    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = item.get_group(1);
    const int lane      = item.get_local_id(1) % WARP_SIZE;
    const int col_idx   = item.get_local_id(0);

    if (col_idx >= k) {
        return;
    }

    const int64_t i03 = batch_idx / ne02;
    const int64_t i02 = batch_idx % ne02;

    const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
    const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
    float *       X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3);

    const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1);

#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    item.barrier(sycl::access::fence_space::local_space);

    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

    const int half      = WARP_SIZE;
    const int nrows_low = (n < half) ? n : half;

#pragma unroll
    for (int row = 0; row < nrows_low; ++row) {
        float sum = 0.0f;
        if (lane < row) {
            sum += sA[row * n + lane] * x_low;
        }
        sum = warp_reduce_sum<WARP_SIZE>(sum);
        if (lane == row) {
            x_low = (x_low - sum) / sA[row * n + row];
        }
    }

#pragma unroll
    for (int row = half; row < n; ++row) {
        float sum = sA[row * n + lane] * x_low;
        const int j = half + lane;
        if (j < row) {
            sum += sA[row * n + j] * x_high;
        }
        sum = warp_reduce_sum<WARP_SIZE>(sum);
        if (lane == row - half) {
            x_high = (x_high - sum) / sA[row * n + row];
        }
    }

#pragma unroll
    for (int rr = 0; rr < 2; ++rr) {
        const int row = rr * WARP_SIZE + lane;
        if (row < n) {
            const float val = (row < half) ? x_low : x_high;
            X_batch[row * k + col_idx] = val;
        }
    }
}

static void solve_tri_f32_mkl(dpct::queue_ptr stream,
                              const float * A, float * X,
                              int n, int k,
                              int64_t ne02, [[maybe_unused]] int64_t ne03,
                              int64_t nb02, [[maybe_unused]] int64_t nb03,
                              int64_t nb2, [[maybe_unused]] int64_t nb3) {
    const float alpha = 1.0f;
    const int64_t total_batches = ne02 * ne03;
    if (total_batches == 0) {
        return;
    }

    const int64_t stride_a = nb02 / sizeof(float);
    const int64_t stride_x = nb2 / sizeof(float);

    oneapi::mkl::blas::trsm_batch(
        *stream,
        oneapi::mkl::side::right,
        oneapi::mkl::uplo::upper,
        oneapi::mkl::transpose::nontrans,
        oneapi::mkl::diag::nonunit,
        k, n, alpha,
        A, n, stride_a,
        X, k, stride_x,
        total_batches);
}

inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(src0->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int n = src0->ne[0];
    const int k = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K);

    const float * A_d = static_cast<const float *>(src0->data);
    const float * B_d = static_cast<const float *>(src1->data);
    float *       X_d = static_cast<float *>(dst->data);

    if (X_d != B_d) {
        const int64_t total_elements = (int64_t) n * k * ne02 * ne03;
        stream->memcpy(X_d, B_d, total_elements * sizeof(float));
    }

    const int64_t nb02 = src0->nb[2];
    const int64_t nb03 = src0->nb[3];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];
    const int64_t nb2  = dst->nb[2];
    const int64_t nb3  = dst->nb[3];

    const int64_t total_batches = ne02 * ne03;

    if (n <= 2 * WARP_SIZE && k <= 32) {
        const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE;
        const sycl::range<2> grid(1, total_batches);
        const sycl::range<2> block(k, WARP_SIZE);
        stream->submit([&](sycl::handler & cgh) {
            sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
            cgh.parallel_for(
                sycl::nd_range<2>(grid * block, block),
                [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03,
                                             nb02, nb03, nb12, nb13, nb2, nb3,
                                             n, k, item, get_pointer(smem_acc));
                });
        });
    } else {
        solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3);
    }
}

void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
    ggml_sycl_op_solve_tri(ctx, dst);
}
@@ -1,8 +0,0 @@
#pragma once

#include "common.hpp"

#define SYCL_SOLVE_TRI_MAX_N 64
#define SYCL_SOLVE_TRI_MAX_K 64

void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
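The fast kernel above performs forward substitution on a lower-triangular A for each of the k right-hand-side columns; the MKL fallback expresses the same row-major system through the usual row-major to column-major transposition trick (side::right, uplo::upper). A scalar C++ sketch of the substitution, matching the kernel's row-major indexing (illustrative, not part of the commit):

// Forward substitution for A * X = B with A lower-triangular (n x n),
// B and X of shape (n x k), all row-major.
static void solve_tri_ref(const float * A, const float * B, float * X,
                          int n, int k) {
    for (int col = 0; col < k; ++col) {
        for (int row = 0; row < n; ++row) {
            float sum = 0.0f;
            for (int j = 0; j < row; ++j) {
                sum += A[row * n + j] * X[j * k + col];   // already-solved rows
            }
            X[row * k + col] = (B[row * k + col] - sum) / A[row * n + row];
        }
    }
}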
@@ -1,156 +0,0 @@
#include "ssm_scan.hpp"
#include "common.hpp"

template <int c_factor, int d_state>
static void ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok,
        const sycl::nd_item<2> & item) {

    const int lane     = item.get_local_id(1) % WARP_SIZE;
    const int warp     = item.get_local_id(1) / WARP_SIZE;
    const int warp_idx = item.get_group(1) * c_factor + warp;
    const int seq_idx  = item.get_group(0);

    const int head_idx  = warp_idx / d_head;
    const int head_off  = (warp_idx % d_head) * sizeof(float);
    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_warp  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_warp  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_warp  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_warp  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float *       y_warp  = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
    float *       s_warp  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    float state[c_factor];
    float state_sum = 0.0f;

#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        state[j] = s0_warp[WARP_SIZE * j + lane];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        const float dt_val       = dt_warp[i * stride_dt];
        const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val);

        state_sum = 0.0f;
        const float dA   = sycl::exp(dt_soft_plus * A_warp[0]);
        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
        for (int j = 0; j < c_factor; j++) {
            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
            state[j] = (state[j] * dA) + (B_val * x_dt);
            state_sum += state[j] * C_val;
        }

        state_sum = warp_reduce_sum<WARP_SIZE>(state_sum);

        if (lane == 0) {
            y_warp[i * stride_y] = state_sum;
        }
    }

#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        s_warp[WARP_SIZE * j + lane] = state[j];
    }
}

static void ssm_scan_f32_sycl(
        const float * src0, const float * src1, const float * src2, const float * src3,
        const float * src4, const float * src5, const int32_t * src6, float * dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
        const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
        const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
        const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
        dpct::queue_ptr stream) {

    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    GGML_ASSERT(src3_nb1 == sizeof(float));
    if (d_state == 128) {
        constexpr int threads   = 128;
        constexpr int num_warps = threads / WARP_SIZE;
        const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
        const sycl::range<2> block(1, threads);
        stream->parallel_for(
            sycl::nd_range<2>(grid * block, block),
            [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                ssm_scan_f32_group<128 / WARP_SIZE, 128>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
            });
    } else if (d_state == 256) {
        constexpr int threads   = 256;
        constexpr int num_warps = threads / WARP_SIZE;
        const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
        const sycl::range<2> block(1, threads);
        stream->parallel_for(
            sycl::nd_range<2>(grid * block, block),
            [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                ssm_scan_f32_group<256 / WARP_SIZE, 256>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
            });
    } else {
        GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)");
    }
}

inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
    const ggml_tensor * src3 = dst->src[3];
    const ggml_tensor * src4 = dst->src[4];
    const ggml_tensor * src5 = dst->src[5];
    const ggml_tensor * src6 = dst->src[6];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t nc    = src0->ne[0];
    const int64_t nr    = src0->ne[1];
    const int64_t nh    = src1->ne[1];
    const int64_t ng    = src4->ne[1];
    const int64_t n_t   = src1->ne[2];
    const int64_t n_s   = src1->ne[3];
    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    ssm_scan_f32_sycl(
        static_cast<const float *>(src0->data), static_cast<const float *>(src1->data),
        static_cast<const float *>(src2->data), static_cast<const float *>(src3->data),
        static_cast<const float *>(src4->data), static_cast<const float *>(src5->data),
        static_cast<const int32_t *>(src6->data), static_cast<float *>(dst->data),
        src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
        src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
        s_off, nc, nr, nh, ng, n_t, n_s, stream);
}

void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7);
    ggml_sycl_op_ssm_scan(ctx, dst);
}
@@ -1,5 +0,0 @@
#pragma once

#include "common.hpp"

void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
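The per-token recurrence the kernel evaluates, written out in scalar C++ for one head channel (illustrative; the WARP_SIZE tiling and byte strides of the kernel are omitted):

// One step of the selective-scan recurrence:
//   dt'  = softplus(dt)                   (pass-through for large dt)
//   dA   = exp(dt' * A)
//   s[j] = s[j] * dA + B[j] * (x * dt')
//   y    = sum_j s[j] * C[j]
#include <cmath>

static float ssm_scan_step(float * s, const float * B, const float * C,
                           float x, float dt, float A, int d_state) {
    const float dtp = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt;  // softplus
    const float dA  = std::exp(dtp * A);
    const float xdt = x * dtp;
    float y = 0.0f;
    for (int j = 0; j < d_state; ++j) {
        s[j] = s[j] * dA + B[j] * xdt;   // state update
        y   += s[j] * C[j];              // output projection
    }
    return y;
}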
@@ -1,154 +0,0 @@
#ifdef USE_SUBGROUP_REDUCTION
enable subgroups;
#endif
enable f16;

#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"

#include "mul_mat_vec_acc.tmpl"

struct MulMatIdVecParams {
    offset_src0: u32,
    offset_src1: u32,
    offset_ids: u32,
    offset_dst: u32,

    k: u32,
    m: u32,
    n_expert: u32,
    n_expert_used: u32,
    b_ne1: u32,

    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
};

@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // [cols, rows, n_expert]
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // [cols, b_ne1, n_tokens(1)]
@group(0) @binding(2) var<storage, read_write> ids: array<u32>;        // [n_expert_used, n_tokens(1)]
@group(0) @binding(3) var<storage, read_write> dst: array<f32>;        // [rows, n_expert_used, n_tokens(1)]

// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01
@group(0) @binding(4) var<uniform> params: MulMatIdVecParams;

// Flattened as [row][thread] to keep each row's reduction contiguous in memory.
var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>;

fn partial_index(row: u32, thread: u32) -> u32 {
    return row * WG_SIZE + thread;
}

var<workgroup> gathered_count_ids: array<u32, N_EXPERTS>;
var<workgroup> gathered_expert_used: array<u32, N_EXPERTS>;

@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
#ifdef USE_SUBGROUP_REDUCTION
    , @builtin(subgroup_id) subgroup_id: u32,
    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
    @builtin(num_subgroups) num_subgroups: u32,
    @builtin(subgroup_size) subgroup_size: u32
#endif
) {

    let thread_id = local_id.x;

    for (var i = thread_id; i < params.n_expert; i += WG_SIZE) {
        gathered_count_ids[i] = 0;
    }

    workgroupBarrier();

    // gather the selected experts for the target token
    for (var col = thread_id; col < params.n_expert_used; col += WG_SIZE) {
        let expert = ids[params.offset_ids + col];
        gathered_count_ids[expert] = 1;
        gathered_expert_used[expert] = col;
    }

    workgroupBarrier();

    let output_groups: u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
    let wg_linear = wg_id.y * num_wg.x + wg_id.x;

    var own_expert: u32 = 0;
    var wg_in_batch: u32 = 0;
    var wg_sum: u32 = 0;

    for (var i = 0u; i < params.n_expert; i += 1) {
        let wg_vec_count = gathered_count_ids[i]; // 1 or 0
        let wg_per_matrix = output_groups * wg_vec_count;
        if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) {
            own_expert = i;
            wg_in_batch = wg_linear - wg_sum;
            break;
        }
        wg_sum += wg_per_matrix;
    }

    let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG;
    let dst1_stride = params.m;

    let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02;
    let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11;
    let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base;

    let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);

#ifdef USE_SUBGROUP_REDUCTION
    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
        let subgroup_total = subgroupAdd(acc[row]);
        if (subgroup_invocation_id == 0u) {
            partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
        }
    }

    workgroupBarrier();

    for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
        let output_row = row_base + row;
        var row_acc = 0.0f;
        for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
            row_acc += partial_sums[partial_index(row, k)];
        }
        let row_total = subgroupAdd(row_acc);
        if (subgroup_invocation_id == 0) {
            dst[dst_idx_base + row] = row_total;
        }
    }
#endif

#ifdef USE_WORKGROUP_REDUCTION
    for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
        partial_sums[partial_index(row, thread_id)] = acc[row];
    }

    workgroupBarrier();

    var stride: u32 = WG_SIZE / 2u;

    while (stride > 0) {
        if (thread_id < stride) {
            for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
                partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
            }
        }

        workgroupBarrier();
        stride = stride / 2;
    }

    if (thread_id < OUTPUTS_PER_WG) {
        let output_row = row_base + thread_id;
        if (output_row < params.m) {
            dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
        }
    }
#endif
}
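The workgroup-to-expert mapping in the shader above is worth spelling out: workgroups are numbered linearly, and each expert that was selected for the token owns a contiguous run of output_groups workgroups. A host-side C++ sketch of that mapping (names are illustrative, not part of the commit):

// Map a linear workgroup index onto (expert, workgroup-within-expert).
#include <cstdint>

struct WgAssign { uint32_t expert; uint32_t wg_in_batch; };

static WgAssign assign_wg(uint32_t wg_linear, uint32_t output_groups,
                          const uint32_t * selected, uint32_t n_expert) {
    uint32_t wg_sum = 0;
    for (uint32_t i = 0; i < n_expert; ++i) {
        const uint32_t wg_per_matrix = output_groups * selected[i];  // 1 or 0 per expert
        if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) {
            return { i, wg_linear - wg_sum };
        }
        wg_sum += wg_per_matrix;
    }
    return { 0, 0 };  // unreachable when wg_linear is in range
}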
(File diff suppressed because it is too large.)
@@ -1,240 +0,0 @@
#if defined(SRC_F16) || defined(DST_F16)
enable f16;
#endif

#ifdef SRC_F16
#define SRC_TYPE f16
#else
#define SRC_TYPE f32
#endif

#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif

@group(0) @binding(0)
var<storage, read_write> input: array<SRC_TYPE>;

@group(0) @binding(1)
var<storage, read_write> output: array<DST_TYPE>;

struct Params {
    offset_i: u32,
    offset_o: u32,

    // element strides
    si0: u32, si1: u32, si2: u32, si3: u32,
    so0: u32, so1: u32, so2: u32, so3: u32,

    src_w: u32,
    src_h: u32,
    src_z: u32,
    src_n: u32,

    dst_w: u32,
    dst_h: u32,
    dst_z: u32,
    dst_n: u32,

    mode_flags: u32,
};

@group(0) @binding(2)
var<uniform> params: Params;

const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u;

fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 {
    let cx = u32(clamp(x, 0, i32(params.src_w) - 1));
    let cy = u32(clamp(y, 0, i32(params.src_h) - 1));
    let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3;
    return f32(input[i]);
}

fn cubic_weight(t: f32, a: f32) -> f32 {
    let at = abs(t);
    if (at <= 1.0) {
        return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0;
    } else if (at <= 2.0) {
        return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a;
    } else {
        return 0.0;
    }
}

@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>
) {

    let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y;
    let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n;

    if (i_out >= total) {
        return;
    }

    // decode (x, y, z, n)
    var i = i_out;
    let x_dst = i % params.dst_w;
    i = i / params.dst_w;
    let y_dst = i % params.dst_h;
    i = i / params.dst_h;
    let z_dst = i % params.dst_z;
    let n_dst = i / params.dst_z;

    // scale factors
    var sf0 = f32(params.dst_w) / f32(params.src_w);
    var sf1 = f32(params.dst_h) / f32(params.src_h);
    var sf2 = f32(params.dst_z) / f32(params.src_z);
    var sf3 = f32(params.dst_n) / f32(params.src_n);

    let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0;

    // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners
    var pixel_offset = 0.5;
    if (align_corners) {
        pixel_offset = 0.0;
        if (params.dst_w > 1 && params.src_w > 1) {
            sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1);
        }
        if (params.dst_h > 1 && params.src_h > 1) {
            sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1);
        }
    }

    let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2)));
    let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3)));

    var result = 0.0;

#if defined(NEAREST)

    let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0)));
    let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1)));

    result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src);

#elif defined(BILINEAR)

#if defined(ANTIALIAS)

    // Antialiased bilinear: triangle filter over a variable support region.
    let support0 = max(1.0f / sf0, 1.0f);
    let support1 = max(1.0f / sf1, 1.0f);
    let invscale0 = 1.0 / support0;
    let invscale1 = 1.0 / support1;

    let fx = (f32(x_dst) + pixel_offset) / sf0;
    let fy = (f32(y_dst) + pixel_offset) / sf1;

    let x_min = max(i32(fx - support0 + pixel_offset), 0);
    let y_min = max(i32(fy - support1 + pixel_offset), 0);
    let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w));
    let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h));

    var weighted_sum = 0.0;
    var total_weight = 0.0;

    for (var x = x_min; x < x_max; x += 1) {
        let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0);
        for (var y = y_min; y < y_max; y += 1) {
            let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0);
            let w = wx * wy;
            if (w > 0.0) {
                weighted_sum += get_clamped_input(x, y, z_src, n_src) * w;
                total_weight += w;
            }
        }
    }

    if (total_weight > 0.0) {
        result = weighted_sum / total_weight;
    }

#else

    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;
    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = clamp(fx - f32(x0), 0.0, 1.0);
    let dy = clamp(fy - f32(y0), 0.0, 1.0);
    let a = get_clamped_input(x0, y0, z_src, n_src);
    let b = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let c = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);

    let wa = (1.0 - dx) * (1.0 - dy);
    let wb = dx * (1.0 - dy);
    let wc = (1.0 - dx) * dy;
    let wd = dx * dy;

    result = a * wa + b * wb + c * wc + d * wd;

#endif

#elif defined(BICUBIC)

    // bicubic convolution with alpha = -0.75 (PyTorch default)
    let alpha = -0.75;
    let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset;
    let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset;

    let x0 = i32(floor(fx));
    let y0 = i32(floor(fy));
    let dx = fx - f32(x0);
    let dy = fy - f32(y0);

    // horizontal weights for offsets -1, 0, 1, 2
    let wx0 = cubic_weight(dx + 1.0, alpha);
    let wx1 = cubic_weight(dx, alpha);
    let wx2 = cubic_weight(1.0 - dx, alpha);
    let wx3 = cubic_weight(2.0 - dx, alpha);

    // vertical weights for offsets -1, 0, 1, 2
    let wy0 = cubic_weight(dy + 1.0, alpha);
    let wy1 = cubic_weight(dy, alpha);
    let wy2 = cubic_weight(1.0 - dy, alpha);
    let wy3 = cubic_weight(2.0 - dy, alpha);

    // intermediate horizontal interpolation for 4x4 grid of pixels
    // x0-1, x0, x0+1, x0+2, y0-1
    let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src);
    let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src);
    let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src);
    let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src);
    let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0
    let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src);
    let q1 = get_clamped_input(x0, y0, z_src, n_src);
    let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src);
    let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src);
    let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0+1
    let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src);
    let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src);
    let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src);
    let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src);
    let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3;

    // x0-1, x0, x0+1, x0+2, y0+2
    let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src);
    let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src);
    let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src);
    let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src);
    let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3;

    // final vertical interpolation
    result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3;

#endif

    let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3;
    output[dst_idx] = DST_TYPE(result);
}
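A host-side C++ mirror of the shader's cubic_weight (the Keys cubic-convolution kernel with alpha = -0.75, as used by the BICUBIC path above) can be handy for spot-checking shader output; this is an illustrative sketch, not part of the commit:

#include <cmath>

// Keys cubic-convolution weight; matches the WGSL cubic_weight above.
static float cubic_weight_ref(float t, float a) {
    const float at = std::fabs(t);
    if (at <= 1.0f) {
        return (a + 2.0f) * at * at * at - (a + 3.0f) * at * at + 1.0f;
    }
    if (at <= 2.0f) {
        return a * at * at * at - 5.0f * a * at * at + 8.0f * a * at - 4.0f * a;
    }
    return 0.0f;
}

// 1-D bicubic interpolation of samples p[-1..2] at fractional position d in [0,1).
static float bicubic_1d_ref(const float p[4], float d, float a = -0.75f) {
    return p[0] * cubic_weight_ref(d + 1.0f, a) + p[1] * cubic_weight_ref(d, a) +
           p[2] * cubic_weight_ref(1.0f - d, a) + p[3] * cubic_weight_ref(2.0f - d, a);
}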
tools/rpc/rpc-server.cpp (Normal file, 342 lines)
@@ -0,0 +1,342 @@
#include "ggml-rpc.h"
#ifdef _WIN32
#  define NOMINMAX
#  define DIRECTORY_SEPARATOR '\\'
#  include <windows.h>
#  include <fcntl.h>
#  include <io.h>
#else
#  define DIRECTORY_SEPARATOR '/'
#  include <unistd.h>
#  include <sys/stat.h>
#endif
#include <algorithm>
#include <clocale>
#include <codecvt>
#include <filesystem>
#include <regex>
#include <stdio.h>
#include <string>
#include <thread>
#include <vector>

#if defined(__linux__)
#include <sys/types.h>
#include <pwd.h>
#endif

// NOTE: this is copied from common.cpp to avoid linking with libcommon
#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string & str) {
    if (str.empty()) {
        return std::wstring();
    }

    int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int) str.size(), NULL, 0);

    if (size <= 0) {
        return std::wstring();
    }

    std::wstring wstr(size, 0);
    MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int) str.size(), &wstr[0], size);

    return wstr;
}
#endif

// NOTE: this is copied from common.cpp to avoid linking with libcommon
// returns true if successful, false otherwise
static bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
    std::wstring wpath = utf8_to_wstring(path);

    // if the path already exists, check whether it's a directory
    const DWORD attributes = GetFileAttributesW(wpath.c_str());
    if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
        return true;
    }

    size_t pos_slash = 0;

    // process path from front to back, procedurally creating directories
    while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
        const std::wstring subpath = wpath.substr(0, pos_slash);

        pos_slash += 1;

        // skip the drive letter; on some systems it can return an access denied error
        if (subpath.length() == 2 && subpath[1] == ':') {
            continue;
        }

        const bool success = CreateDirectoryW(subpath.c_str(), NULL);

        if (!success) {
            const DWORD error = GetLastError();

            // if the path already exists, ensure that it's a directory
            if (error == ERROR_ALREADY_EXISTS) {
                const DWORD attributes = GetFileAttributesW(subpath.c_str());
                if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
                    return false;
                }
            } else {
                return false;
            }
        }
    }

    return true;
#else
    // if the path already exists, check whether it's a directory
    struct stat info;
    if (stat(path.c_str(), &info) == 0) {
        return S_ISDIR(info.st_mode);
    }

    size_t pos_slash = 1; // skip leading slashes for directory creation

    // process path from front to back, procedurally creating directories
    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
        const std::string subpath = path.substr(0, pos_slash);
        struct stat info;

        // if the path already exists, ensure that it's a directory
        if (stat(subpath.c_str(), &info) == 0) {
            if (!S_ISDIR(info.st_mode)) {
                return false;
            }
        } else {
            // create parent directories
            const int ret = mkdir(subpath.c_str(), 0755);
            if (ret != 0) {
                return false;
            }
        }

        pos_slash += 1;
    }

    return true;
#endif // _WIN32
}

// NOTE: this is copied from common.cpp to avoid linking with libcommon
static std::string fs_get_cache_directory() {
    std::string cache_directory = "";
    auto ensure_trailing_slash = [](std::string p) {
        // make sure to add a trailing slash
        if (p.back() != DIRECTORY_SEPARATOR) {
            p += DIRECTORY_SEPARATOR;
        }
        return p;
    };
    if (getenv("LLAMA_CACHE")) {
        cache_directory = std::getenv("LLAMA_CACHE");
    } else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
    defined(__OpenBSD__) || defined(__NetBSD__)
        if (std::getenv("XDG_CACHE_HOME")) {
            cache_directory = std::getenv("XDG_CACHE_HOME");
        } else if (std::getenv("HOME")) {
            cache_directory = std::getenv("HOME") + std::string("/.cache/");
        } else {
#if defined(__linux__)
            /* no $HOME is defined, fall back to getpwuid */
            struct passwd * pw = getpwuid(getuid());
            if ((!pw) || (!pw->pw_dir)) {
                throw std::runtime_error("Failed to find $HOME directory");
            }

            cache_directory = std::string(pw->pw_dir) + std::string("/.cache/");
#else  /* defined(__linux__) */
            throw std::runtime_error("Failed to find $HOME directory");
#endif /* defined(__linux__) */
        }
#elif defined(__APPLE__)
        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
        cache_directory = std::getenv("LOCALAPPDATA");
#elif defined(__EMSCRIPTEN__)
        GGML_ABORT("not implemented on this platform");
#else
#  error Unknown architecture
#endif
        cache_directory = ensure_trailing_slash(cache_directory);
        cache_directory += "llama.cpp";
    }
    return ensure_trailing_slash(cache_directory);
}

struct rpc_server_params {
    std::string host      = "127.0.0.1";
    int         port      = 50052;
    bool        use_cache = false;
    int         n_threads = std::max(1U, std::thread::hardware_concurrency() / 2);
    std::vector<std::string> devices;
};

static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
    fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help                    show this help message and exit\n");
    fprintf(stderr, "  -t, --threads N               number of threads for the CPU device (default: %d)\n", params.n_threads);
    fprintf(stderr, "  -d, --device <dev1,dev2,...>  comma-separated list of devices\n");
    fprintf(stderr, "  -H, --host HOST               host to bind to (default: %s)\n", params.host.c_str());
    fprintf(stderr, "  -p, --port PORT               port to bind to (default: %d)\n", params.port);
    fprintf(stderr, "  -c, --cache                   enable local file cache\n");
    fprintf(stderr, "\n");
}

static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
    std::string arg;
    for (int i = 1; i < argc; i++) {
        arg = argv[i];
        if (arg == "-H" || arg == "--host") {
            if (++i >= argc) {
                return false;
            }
            params.host = argv[i];
        } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                return false;
            }
            params.n_threads = std::stoi(argv[i]);
            if (params.n_threads <= 0) {
                fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
                return false;
            }
        } else if (arg == "-d" || arg == "--device") {
            if (++i >= argc) {
                return false;
            }
            const std::regex regex{ R"([,/]+)" };
            std::string dev_str = argv[i];
            std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
            std::sregex_token_iterator end;
            for ( ; iter != end; ++iter) {
                try {
                    params.devices.push_back(*iter);
                } catch (const std::exception &) {
                    fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
                    return false;
                }
            }
        } else if (arg == "-p" || arg == "--port") {
            if (++i >= argc) {
                return false;
            }
            params.port = std::stoi(argv[i]);
            if (params.port <= 0 || params.port > 65535) {
                return false;
            }
        } else if (arg == "-c" || arg == "--cache") {
            params.use_cache = true;
        } else if (arg == "-h" || arg == "--help") {
            print_usage(argc, argv, params);
            exit(0);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            print_usage(argc, argv, params);
            exit(0);
        }
    }
    return true;
}

static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
    std::vector<ggml_backend_dev_t> devices;
    if (!params.devices.empty()) {
        for (auto device : params.devices) {
            ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
            if (dev) {
                devices.push_back(dev);
            } else {
                fprintf(stderr, "error: unknown device: %s\n", device.c_str());
                fprintf(stderr, "available devices:\n");
                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
                    auto * dev = ggml_backend_dev_get(i);
                    size_t free, total;
                    ggml_backend_dev_memory(dev, &free, &total);
                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
                }
                return {};
            }
        }
    }

    // try non-CPU devices first
    if (devices.empty()) {
        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
                devices.push_back(dev);
            }
        }
    }

    // if there are no accelerators, fall back to the CPU device
    if (devices.empty()) {
        ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
        if (dev) {
            devices.push_back(dev);
        }
    }

    return devices;
}

int main(int argc, char * argv[]) {
    std::setlocale(LC_NUMERIC, "C");

    ggml_backend_load_all();

    rpc_server_params params;
    if (!rpc_server_params_parse(argc, argv, params)) {
        fprintf(stderr, "Invalid parameters\n");
        return 1;
    }

    if (params.host != "127.0.0.1") {
        fprintf(stderr, "\n");
        fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
        fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
        fprintf(stderr, "         Never expose the RPC server to an open network!\n");
        fprintf(stderr, "         This is an experimental feature and is not secure!\n");
        fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
        fprintf(stderr, "\n");
    }

    auto devices = get_devices(params);
    if (devices.empty()) {
        fprintf(stderr, "No devices found\n");
        return 1;
    }
    std::string endpoint = params.host + ":" + std::to_string(params.port);
    const char * cache_dir = nullptr;
    std::string cache_dir_str;
    if (params.use_cache) {
        cache_dir_str = fs_get_cache_directory() + "rpc" + DIRECTORY_SEPARATOR;
        if (!fs_create_directory_with_parents(cache_dir_str)) {
            fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
            return 1;
        }
        cache_dir = cache_dir_str.c_str();
    }

    ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
    if (!reg) {
        fprintf(stderr, "Failed to find RPC backend\n");
        return 1;
    }

    auto start_server_fn = (decltype(ggml_backend_rpc_start_server) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
    if (!start_server_fn) {
        fprintf(stderr, "Failed to obtain RPC backend start server function\n");
        return 1;
    }

    start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
    return 0;
}
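A typical invocation, using only the flags defined by print_usage above (the host and port shown are the built-in defaults; the thread count is illustrative):

    rpc-server -H 127.0.0.1 -p 50052 -t 8 -c

This binds the loopback endpoint 127.0.0.1:50052, serves the detected non-CPU devices (falling back to the CPU device if none are found), and enables the local file cache under the llama.cpp cache directory.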