Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	examples/model-conversion/Makefile
#	examples/model-conversion/README.md
#	examples/model-conversion/scripts/utils/quantize.sh
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/group_norm.cl
#	ggml/src/ggml-opencl/kernels/norm.cl
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	tests/test-backend-ops.cpp
#	tests/test-opt.cpp
#	tools/batched-bench/batched-bench.cpp
#	tools/mtmd/CMakeLists.txt
Author: Concedo
Date:   2025-08-27 17:39:24 +08:00
Commit: 654b9eee73
11 changed files with 595 additions and 178 deletions

View file

@ -1108,7 +1108,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
printf("\"\n\n");
printf(" case \"$prev\" in\n");
printf(" --model)\n");
printf(" --model|-m)\n");
printf(" COMPREPLY=( $(compgen -f -X '!*.gguf' -- \"$cur\") $(compgen -d -- \"$cur\") )\n");
printf(" return 0\n");
printf(" ;;\n");

View file

@ -2169,94 +2169,117 @@ class tinyBLAS_Q0_PPC {
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
const float *A, int64_t lda,
const float *B, int64_t ldb,
float *C, int64_t ldc,
const float * A, int64_t lda,
const float * B, int64_t ldb,
float * C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
matmul_tiled(m, n, mc, nc, kc);
} else {
mnpack(0, m, 0, n);
}
}
private:
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
inline void vector_permute_store_4(vector float *src, float *vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergel(src[0], src[1]);
t4 = vec_mergel(src[2], src[3]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
}
inline void vector_permute_store_8(vector float *src, float *vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergeh(src[4], src[5]);
t4 = vec_mergeh(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
t1 = vec_mergel(src[0], src[1]);
t2 = vec_mergel(src[2], src[3]);
t3 = vec_mergel(src[4], src[5]);
t4 = vec_mergel(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset + 16);
vec_xst(t6, 0, vecOffset + 20);
vec_xst(t7, 0, vecOffset + 24);
vec_xst(t8, 0, vecOffset + 28);
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
vec_t vec_C[4];
__builtin_mma_disassemble_acc(vec_C, ACC);
for (int I = 0; I < 4; I++) {
for (int J = 0; J < 4; J++) {
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
*c_ptr += *((float *)&vec_C[I]+J);
}
}
}
inline void vector_permute_store_4(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergel(src[0], src[1]);
t4 = vec_mergel(src[2], src[3]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
}
inline void vector_permute_store_8(vector float * src, float * vecOffset) {
vector float t1, t2, t3, t4, t5, t6, t7, t8;
t1 = vec_mergeh(src[0], src[1]);
t2 = vec_mergeh(src[2], src[3]);
t3 = vec_mergeh(src[4], src[5]);
t4 = vec_mergeh(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset);
vec_xst(t6, 0, vecOffset + 4);
vec_xst(t7, 0, vecOffset + 8);
vec_xst(t8, 0, vecOffset + 12);
t1 = vec_mergel(src[0], src[1]);
t2 = vec_mergel(src[2], src[3]);
t3 = vec_mergel(src[4], src[5]);
t4 = vec_mergel(src[6], src[7]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, vecOffset + 16);
vec_xst(t6, 0, vecOffset + 20);
vec_xst(t7, 0, vecOffset + 24);
vec_xst(t8, 0, vecOffset + 28);
}
void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
int64_t i, j;
float * aoffsets[8];
float *aoffset = NULL, *boffset = NULL;
float * aoffset = NULL, * boffset = NULL;
__vector_pair arr[8];
vector float c[8][2] = {0};
vector float c1[8] = {0};
vector float c2[8] = {0};
aoffset = const_cast<float*>(a);
aoffset = const_cast<float *>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffsets[0] = aoffset;
for (int it = 1; it< 8; it++)
for (int it = 1; it < 8; it++)
aoffsets[it] = aoffsets[it-1] + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
do {
for (int it = 0; it< 8; it++) {
for (int it = 0; it < 8; it++) {
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
c1[it] = c[it][0];
@ -2264,11 +2287,14 @@ class tinyBLAS_PPC {
}
vector_permute_store_8(c1, boffset);
vector_permute_store_8(c2, boffset+32);
for (int it = 0; it < 4; it++)
aoffsets[it] = aoffsets[it] + 8*lda;
vector_permute_store_8(c2, boffset + 32);
boffset += 64;
i--;
if (i > 0) {
for (int it = 0; it < 8; it++) {
aoffsets[it] = aoffsets[it] + 8;
}
}
} while(i > 0);
}
if (cols & 4) {
@ -2295,9 +2321,9 @@ class tinyBLAS_PPC {
c2[it] = c[it][1];
}
vector_permute_store_4(c1, boffset);
vector_permute_store_4(c2, boffset+16);
vector_permute_store_4(c2, boffset + 16);
for (int it = 0; it < 4; it++)
aoffsets[it] += 8*lda;
aoffsets[it] += 8 * lda;
boffset += 32;
i--;
} while(i > 0);
@ -2325,15 +2351,15 @@ class tinyBLAS_PPC {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
for (int l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
for (int l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
save_acc(&acc_0, ii, jj);
}
void KERNEL_4x8(int64_t ii, int64_t jj) {
@ -2341,9 +2367,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
for (int64_t l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@ -2353,8 +2379,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii, jj + 4);
}
void KERNEL_8x4(int64_t ii, int64_t jj) {
@ -2362,9 +2388,9 @@ class tinyBLAS_PPC {
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
for (int64_t l = 0; l < k; l += 4) {
packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@ -2374,8 +2400,8 @@ class tinyBLAS_PPC {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii+4, jj);
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii + 4, jj);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
@ -2386,19 +2412,96 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
}
}
save_acc(&acc_0, ii, jj);
save_acc(&acc_1, ii, jj + 4);
save_acc(&acc_2, ii + 4, jj);
save_acc(&acc_3, ii + 4, jj + 4);
}
inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
for (int x = 0; x < 16; x += 2) {
__builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
__builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
}
}
void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
for (int64_t i = 0; i < mc; i += 16) {
int A_base_addr = (mc / 8) * (i / 8) * 16;
for (int64_t j = 0; j < nc; j += 8) {
int B_base_addr = (nc / 8) * (j / 8) * 16;
acc_t acc[8];
vec_t A0_block[16]; vec_t A1_block[16];
for (int x = 0; x < 8; x++)
__builtin_mma_xxsetaccz(&acc[x]);
for (int64_t l = 0; l < kc; l += 8) {
int A0_block_idx = A_base_addr + (l / 8) * 16;
int A1_block_idx = A0_block_idx + (mc / 8) * 16;
int B_block_idx = B_base_addr + (l / 8) * 16;
vec_t* A0_block = &vec_A[A0_block_idx];
vec_t* A1_block = &vec_A[A1_block_idx];
vec_t* B_block = &vec_B[B_block_idx];
MMA_16x8(A0_block, A1_block, B_block, acc);
}
if (kk == 0) {
save_acc(&acc[0], ii + i, jj + j);
save_acc(&acc[1], ii + i, jj + j + 4);
save_acc(&acc[2], ii + i + 4, jj + j);
save_acc(&acc[3], ii + i + 4, jj + j + 4);
save_acc(&acc[4], ii + i + 8, jj + j);
save_acc(&acc[5], ii + i + 8, jj + j + 4);
save_acc(&acc[6], ii + i + 12, jj + j);
save_acc(&acc[7], ii + i + 12, jj + j + 4);
} else {
add_save_acc(&acc[0], ii + i, jj + j);
add_save_acc(&acc[1], ii + i, jj + j + 4);
add_save_acc(&acc[2], ii + i + 4, jj + j);
add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
add_save_acc(&acc[4], ii + i + 8, jj + j);
add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
add_save_acc(&acc[6], ii + i + 12, jj + j);
add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
}
}
}
}
void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
vec_t A_pack[kc * mc / 4];
vec_t B_pack[kc * nc / 4];
packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
}
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
SAVE_ACC(&acc_2, ii+4, jj);
SAVE_ACC(&acc_3, ii+4, jj+4);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
@ -2406,35 +2509,35 @@ class tinyBLAS_PPC {
int n_rem = MIN(n - n0, 8);
int mc = 0, nc = 0;
if (m_rem >= 8 && n_rem >= 8) {
mc = 8;
nc = 8;
gemm<8, 8>(m0, m, n0, n);
mc = 8;
nc = 8;
gemm<8, 8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
mc = 4;
nc = 8;
gemm<4, 8>(m0, m, n0, n);
mc = 4;
nc = 8;
gemm<4, 8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
mc = 8;
nc = 4;
gemm<8, 4>(m0, m, n0, n);
mc = 8;
nc = 4;
gemm<8, 4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
} else {
mc = (m_rem >= 4) ? 4 : m_rem;
nc = (n_rem >= 4) ? 4 : n_rem;
if (mc == 0 || nc == 0)
return;
return;
gemm_small(m0, m, n0, n, mc, nc);
}
int64_t mp = m0 + ((m - m0) / mc) * mc;
int64_t np = n0 + ((n - n0) / nc) * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@ -2449,30 +2552,30 @@ class tinyBLAS_PPC {
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4] {0}, vec_B[4] = {0};
for (int l=0; l<k; l+=4) {
vec_t vec_A[4] = {0}, vec_B[4] = {0};
for (int l = 0; l < k; l += 4) {
/* The 'GEMV forwarding' concept is used in the first two conditional branches:
* when one of the matrices has a single row/column, its elements are
* broadcast instead of going through the packing routine to prepack the
* matrix elements.
*/
if (RM == 1) {
float* a = const_cast<float*>(A+(ii)*lda+l);
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
float * a = const_cast<float *>(A + (ii) * lda + l);
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
} else if (RN == 1) {
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
float* b = const_cast<float*>(B+(jj)*ldb+l);
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
float * b = const_cast<float *>(B + (jj) * ldb + l);
vec_B[0] = (vec_t)vec_xl(0,b);
vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
} else {
packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@ -2482,12 +2585,27 @@ class tinyBLAS_PPC {
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
*((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
}
}
}
}
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 4) {
KERNEL_4x4(ii, jj);
} else if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii, jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii, jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii, jj);
} else {
static_assert(false, "RN/RM values not supported");
}
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
@ -2496,27 +2614,18 @@ class tinyBLAS_PPC {
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (RM == 4 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_4x4;
} else if (RM == 4 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_4x8;
} else if (RM == 8 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_8x4;
} else if (RM == 8 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_8x8;
}
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
(this->*kernel)(ii, jj);
kernel<RM, RN>(ii, jj);
}
}
const float *const A;
const float *const B;
float *C;
const float * const A;
const float * const B;
float * C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
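
For reference, the new matmul_tiled/KERNEL path splits the output into mc x nc tiles, distributes whole tiles across the nth threads, and accumulates each tile over kc-wide slices of K (save_acc on the first slice, add_save_acc afterwards). A minimal scalar sketch of that scheduling, ignoring the MMA packing details (illustrative only, not part of the patch):

```cpp
// Scalar reference for the job scheduling in matmul_tiled/KERNEL above
// (hypothetical helper, no MMA packing): the (m/mc) x (n/nc) grid of output
// tiles is flattened into jobs and divided across nth threads; each job
// accumulates its tile over kc-wide slices of K, overwriting C on the first
// slice (like save_acc) and accumulating afterwards (like add_save_acc).
#include <cstdint>

static void matmul_tiled_sketch(const float * A, const float * B, float * C,
                                int64_t m, int64_t n, int64_t k,
                                int64_t lda, int64_t ldb, int64_t ldc,
                                int ith, int nth,
                                int64_t mc = 256, int64_t nc = 256, int64_t kc = 256) {
    const int64_t xtiles = n / nc;
    const int64_t tiles  = (m / mc) * xtiles;
    const int64_t duty   = (tiles + nth - 1) / nth;       // tiles per thread
    const int64_t start  = duty * ith;
    const int64_t end    = start + duty > tiles ? tiles : start + duty;

    for (int64_t job = start; job < end; ++job) {
        const int64_t ii = (job / xtiles) * mc;            // tile origin in M
        const int64_t jj = (job % xtiles) * nc;            // tile origin in N
        for (int64_t kk = 0; kk < k; kk += kc) {           // K slices
            for (int64_t i = 0; i < mc; ++i) {
                for (int64_t j = 0; j < nc; ++j) {
                    float acc = kk == 0 ? 0.0f : C[(jj + j) * ldc + ii + i];
                    for (int64_t l = 0; l < kc; ++l) {
                        acc += A[(ii + i) * lda + kk + l] * B[(jj + j) * ldb + kk + l];
                    }
                    C[(jj + j) * ldc + ii + i] = acc;      // same C layout as save_acc
                }
            }
        }
    }
}
```

Because each output tile is owned by exactly one thread, the kk loop needs no synchronization between slices.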

View file

@ -9072,6 +9072,9 @@ static void ggml_compute_forward_ssm_scan_f32(
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
#elif defined(__riscv_v_intrinsic)
// todo: RVV implementation
const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
@ -10023,8 +10026,8 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
int64_t h_stride_2d = head_size * head_size;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
// scalar Route to scalar implementation //TODO: Write SVE code
#if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
// scalar Route to scalar implementation //TODO: Write SVE code and RVV code
for (int64_t t = 0; t < T; t++) {
int64_t t_offset = t * t_stride;
int64_t state_offset = head_size * C * (t / (T / n_seqs));

View file

@ -18,6 +18,10 @@
#include <immintrin.h>
#endif
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -94,24 +98,15 @@ extern "C" {
}
#elif defined(__riscv) && defined(__riscv_zfhmin)
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
float f;
__asm__(
"fmv.h.x %[f], %[h]\n\t"
"fcvt.s.h %[f], %[f]"
: [f] "=&f" (f)
: [h] "r" (h)
);
return f;
_Float16 hf;
memcpy(&hf, &h, sizeof(ggml_fp16_t));
return hf;
}
static inline ggml_fp16_t riscv_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
__asm__(
"fcvt.h.s %[f], %[f]\n\t"
"fmv.x.h %[h], %[f]"
: [h] "=&r" (res)
: [f] "f" (f)
);
_Float16 hf = (_Float16)f;
memcpy(&res, &hf, sizeof(ggml_fp16_t));
return res;
}
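
The rewritten conversions lean on the compiler's native _Float16 support (available with zfhmin) instead of hand-written inline asm, with memcpy moving the raw 16-bit pattern so no aliasing rules are violated. A standalone round-trip sketch under those assumptions (hypothetical names, illustrative only):

```cpp
// Round-trip sketch of the memcpy-based conversion above; assumes a toolchain
// where _Float16 is available (e.g. RISC-V with zfhmin), where the casts map
// to fcvt.s.h / fcvt.h.s. Names are illustrative stand-ins.
#include <cstdint>
#include <cstdio>
#include <cstring>

typedef uint16_t fp16_bits; // stand-in for ggml_fp16_t

static inline float fp16_to_fp32_sketch(fp16_bits h) {
    _Float16 hf;
    memcpy(&hf, &h, sizeof(h)); // reinterpret the 16-bit pattern without aliasing UB
    return (float) hf;
}

static inline fp16_bits fp32_to_fp16_sketch(float f) {
    const _Float16 hf = (_Float16) f;
    fp16_bits h;
    memcpy(&h, &hf, sizeof(h));
    return h;
}

int main() {
    const float x = 3.140625f; // exactly representable in fp16
    const fp16_bits h = fp32_to_fp16_sketch(x);
    printf("0x%04x -> %f\n", (unsigned) h, fp16_to_fp32_sketch(h)); // same value back
    return 0;
}
```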
@ -1170,6 +1165,36 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
#elif defined(__riscv_v_intrinsic)
// compatible with vlen >= 128
#define GGML_SIMD
// F32
#define GGML_F32_STEP 16
#define GGML_F32_EPR 4
#define GGML_F32x4 vfloat32m1_t
#define GGML_F32x4_ZERO __riscv_vfmv_v_f_f32m1(0.0f, GGML_F32_EPR)
#define GGML_F32x4_SET1(x) __riscv_vfmv_v_f_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_LOAD(x) __riscv_vle32_v_f32m1(x, GGML_F32_EPR)
#define GGML_F32x4_STORE(b, v) __riscv_vse32_v_f32m1(b, v, GGML_F32_EPR)
#define GGML_F32x4_FMA(a, b, c) __riscv_vfmacc_vv_f32m1(a, b, c, GGML_F32_EPR)
#define GGML_F32x4_ADD(a, b) __riscv_vfadd_vv_f32m1(a, b, GGML_F32_EPR)
#define GGML_F32x4_MUL(a, b) __riscv_vfmul_vv_f32m1(a, b, GGML_F32_EPR)
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
#endif
// GGML_F32_ARR / GGML_F16_ARR
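
These macros follow the generic GGML_SIMD contract with a fixed element count of GGML_F32_EPR = 4 per vector, so GGML_F32_ARR works out to GGML_F32_STEP / GGML_F32_EPR = 4 unrolled vectors per step. A hypothetical axpy-style loop showing how the generic pattern consumes them (assumes an RVV toolchain and the macro block above in scope; illustrative only):

```cpp
// Hypothetical use of the new RVV GGML_F32 macros in the generic GGML_SIMD
// pattern: GGML_F32_STEP = 16 elements per unrolled step, GGML_F32_EPR = 4
// elements per vector, so 4 vectors are processed per step.
#include <riscv_vector.h>

static void vec_mad_sketch(int n, float * y, const float * x, float v) {
    const int np = n & ~(16 - 1);                      // multiples of GGML_F32_STEP
    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
    for (int i = 0; i < np; i += 16) {
        for (int j = 0; j < 4; ++j) {                  // GGML_F32_STEP / GGML_F32_EPR
            GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*4);
            GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*4);
            ay = GGML_F32_VEC_FMA(ay, ax, vx);         // ay += ax * vx
            GGML_F32_VEC_STORE(y + i + j*4, ay);
        }
    }
    for (int i = np; i < n; ++i) {                     // scalar tail
        y[i] += x[i]*v;
    }
}
```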

View file

@ -84,6 +84,16 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic)
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
}
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -197,7 +207,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
#if defined(GGML_SIMD)
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@ -325,6 +335,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
vst1q_f32(y + i, val);
sum += (ggml_float)vaddvq_f32(val);
}
#elif defined(__riscv_v_intrinsic)
vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
for (int avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
__riscv_vse32_v_f32m2(&y[i], val, avl);
vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
}
return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
for (; i < n; ++i) {
float val = expf(x[i] - max);

View file

@ -119,6 +119,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
@ -149,6 +157,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@ -243,6 +252,14 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
svst1_f32(pg, y + np2, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -276,6 +293,13 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -297,6 +321,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -324,6 +349,16 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
y[i] += x[k][i]*v[k][0];
}
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
for (int k = 0; k < GGML_VEC_MAD_UNROLL; k++) {
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
}
__riscv_vse32_v_f32m8(&y[i], ay, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -375,6 +410,14 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -436,6 +479,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
ay1 = svmul_f32_m(pg, ay1, vx);
svst1_f32(pg, y + np, ay1);
}
#elif defined(__riscv_v_intrinsic)
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
__riscv_vse32_v_f32m8(&y[i], ny, avl);
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -467,6 +517,13 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -486,6 +543,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -928,7 +986,51 @@ inline static __m128 ggml_v_silu(__m128 x) {
return _mm_div_ps(x, one_plus_exp_neg_x);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__
#elif defined(__riscv_v_intrinsic)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
#ifdef __riscv_xtheadvector
// workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
#else
const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
#endif
const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
0x1.7f7d1cp-20f, n, vl);
const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
__riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
__riscv_vfmacc_vv_f32m2(
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
__riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
u, vl), u, vl);
if (!__riscv_vcpop_m_b16(c, vl))
return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
__riscv_vfmacc_vv_f32m2(k, k, j, vl),
__riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
c, vl);
return __riscv_vmerge_vvm_f32m2(
r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
__riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
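
ggml_v_expf_m2 follows the usual range-reduction scheme: n = round(x*log2(e)) is recovered with the 0x1.8p23 rounding trick, 2^n is built by shifting n into the float exponent field, and a short polynomial approximates exp(r) on the reduced range. A scalar sketch of the common path with the same constants but none of the |n| > 126 special-case handling (illustrative only, assumes C++17 hex float literals):

```cpp
// Scalar sketch of the common path of ggml_v_expf_m2 above:
// exp(x) = 2^n * exp(r), n = round(x*log2(e)), r = x - n*ln(2).
// Overflow/underflow handling of the vector version is omitted.
#include <cmath>
#include <cstdint>
#include <cstring>

static inline float expf_sketch(float x) {
    const float r = 0x1.8p23f;                       // 1.5 * 2^23 forces round-to-integer
    const float z = r + x * 0x1.715476p+0f;          // low mantissa bits of z hold n
    const float n = z - r;                           // n = round(x * log2(e))
    const float b = x - n * 0x1.62e4p-1f - n * 0x1.7f7d1cp-20f; // r, with ln(2) split in two

    uint32_t bits;
    std::memcpy(&bits, &z, sizeof(bits));
    bits = (bits << 23) + 0x3f800000u;               // 2^n as a float bit pattern
    float two_n;
    std::memcpy(&two_n, &bits, sizeof(two_n));

    const float u = b * b;
    const float j = std::fmaf(std::fmaf(std::fmaf(0x1.0e4020p-7f, b, 0x1.573e2ep-5f), u,
                              std::fmaf(0x1.555e66p-3f, b, 0x1.fffdb6p-2f)), u,
                              0x1.ffffecp-1f * b);   // j ~= exp(r) - 1
    return std::fmaf(two_n, j, two_n);               // 2^n * (1 + j)
}
```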
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
for (int i = 0; i < n; ++i) {

View file

@ -107,9 +107,10 @@ constexpr bool ggml_cuda_has_arch(const int arch) {
return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
if (cur == 0) {
GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
GGML_LOG_WARN("ggml was not compiled with any CUDA arch warning");
return -1;
}
return cur;
}

View file

@ -249,6 +249,7 @@ typedef struct {
uint64_t nb33;
int32_t ne1;
int32_t ne2;
int32_t ne3;
float scale;
float max_bias;
float m0;
@ -257,6 +258,11 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t nrows;
int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;
typedef struct {
int32_t ne00;
int32_t ne02;

View file

@ -291,6 +291,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@ -575,6 +579,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@ -1324,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@ -1609,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
const int ne11_mm_min = 4;
const int ne11_mm_min = 8;
// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
src0t == GGML_TYPE_F16 || // TODO: helper function
src0t == GGML_TYPE_F32 || // TODO: helper function
src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
// num threads along row per simdgroup
int nxpsg = 0;
if (ne00 % 256 == 0 && ne11 < 3) {
nxpsg = 16;
} else if (ne00 % 128 == 0) {
nxpsg = 8;
} else {
nxpsg = 4;
}
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
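
For a concrete feel of how these parameters combine, a tiny worked example (hypothetical values: ne00 divisible by 128 but not 256, so nxpsg = 8; nsg is fixed at 2 above):

```cpp
// Illustrative arithmetic for the small-batch mat-vec dispatch above.
#include <cstdio>

int main() {
    const int nsg   = 2;             // simdgroups per threadgroup (fixed above)
    const int nxpsg = 8;             // threads along a row per simdgroup
    const int nypsg = 32 / nxpsg;    // 4 src0 rows per simdgroup
    const int r0ptg = nypsg * nsg;   // 8 src0 rows per threadgroup
    const int r1ptg = 4;             // src1 rows per threadgroup -> _r1_4 kernel variant
    printf("nypsg=%d r0ptg=%d r1ptg=%d\n", nypsg, r0ptg, r1ptg);
    return 0;
}
```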
@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
const size_t smem = FATTN_SMEM(nsg);
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@ -5561,15 +5593,17 @@ static int ggml_metal_encode_node(
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
// ne00*(nsg)
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
if (smem > device.maxThreadgroupMemoryLength) {
// avoid using more than half of the threadgroup memory - can cause slowdowns, especially for large head sizes
if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
@ -5577,7 +5611,7 @@ static int ggml_metal_encode_node(
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@ -5585,13 +5619,74 @@ static int ggml_metal_encode_node(
}
nsg /= 2;
const size_t smem = FATTN_SMEM(nsg);
// workgroups
// each workgroup handles nsg*nkpsg cache values
uint16_t nwg = 1;
if (4*nsg*nkpsg >= ne11) {
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
// using 1 workgroup -> write the result directly into dst
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
nwg = 32;
nsg = MIN(4, nsg);
const size_t smem = FATTN_SMEM(nsg);
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
// sanity checks
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the head vector
// - + 2: the S and M values for each intermediate result
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
if (!h_tmp) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
return 0;
}
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
[encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// reduce the results from the workgroups
{
ggml_metal_kargs_flash_attn_ext_reduce args0 = {
nrows,
ne20,
};
id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:h_tmp offset:0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
}
}
#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:

View file

@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
@ -3015,7 +3020,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
}
}
@ -3200,6 +3204,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
@ -4786,14 +4795,16 @@ kernel void kernel_flash_attn_ext_vec(
device const char * mask,
device const char * sinks,
device char * dst,
constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const short iwg = tgpig[2]%nwg;
const int iq3 = tgpig[2];
const int iq3 = tgpig[2]/nwg;
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
@ -4872,7 +4883,7 @@ kernel void kernel_flash_attn_ext_vec(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
@ -5002,7 +5013,7 @@ kernel void kernel_flash_attn_ext_vec(
}
}
if (sinks != q && sgitg == 0) {
if (sinks != q && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
@ -5111,14 +5122,25 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
const int64_t nrows = args.ne3*args.ne2*args.ne1;
const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
device float4 * dst4 = (device float4 *) dst;
device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
// interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
}
// store S and M
if (nwg > 1 && tiisg == 0) {
dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
}
}
}
@ -5218,6 +5240,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
#undef FA_TYPES
kernel void kernel_flash_attn_ext_reduce(
constant ggml_metal_kargs_flash_attn_ext_reduce & args,
device const char * htmp,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const uint64_t rid = tgpig;
const short nwg = 32;
const short iwg = tiisg;
const short DV = args.ne20;
const short DV4 = DV/4;
device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
device float4 * dst4 = (device float4 *) dst + rid*DV4;
float S = ss[rid*(2*nwg) + 2*iwg + 0];
float M = ss[rid*(2*nwg) + 2*iwg + 1];
const float m = simd_max(M);
const float ms = exp(M - m);
S = 1.0f/simd_sum(S*ms);
for (int i = sgitg; i < DV4; i += nwg) {
const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
if (iwg == 0) {
dst4[i] = v*S;
}
}
}
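
kernel_flash_attn_ext_reduce merges the per-workgroup partial results using their saved (S, M) statistics: the partial outputs are rescaled to a common maximum before summing, then normalized by the merged denominator. A host-side scalar sketch of that merge (hypothetical C++ names, illustrative only):

```cpp
// Scalar sketch of the partial-attention merge performed above: each of the
// nwg workgroups left an un-normalized output accumulator O[w], a softmax
// denominator S[w] and a running maximum M[w]; they are combined with a
// common rescale to max(M).
#include <algorithm>
#include <cmath>
#include <vector>

static void merge_partials(const std::vector<std::vector<float>> & O, // [nwg][DV]
                           const std::vector<float> & S,              // [nwg]
                           const std::vector<float> & M,              // [nwg]
                           std::vector<float> & dst) {                // [DV]
    const size_t nwg = O.size();
    const size_t DV  = dst.size();

    float m = -INFINITY;
    for (size_t w = 0; w < nwg; ++w) {
        m = std::max(m, M[w]);                       // merged running maximum
    }

    float S_total = 0.0f;
    for (size_t w = 0; w < nwg; ++w) {
        S_total += S[w] * std::exp(M[w] - m);        // merged denominator
    }

    for (size_t i = 0; i < DV; ++i) {
        float acc = 0.0f;
        for (size_t w = 0; w < nwg; ++w) {
            acc += O[w][i] * std::exp(M[w] - m);     // rescaled partial outputs
        }
        dst[i] = acc / S_total;                      // final normalization
    }
}
```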
template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,

View file

@ -1376,7 +1376,7 @@ ggml_tensor * llm_graph_context::build_attn(
// [TAG_NO_CACHE_PAD]
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
assert(!ubatch.equal_seqs());
assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
ggml_tensor * q = q_cur;
ggml_tensor * k = k_cur;