diff --git a/CMakeLists.txt b/CMakeLists.txt index bc4918c0a..9b6a725d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,7 +124,7 @@ if (LLAMA_CUBLAS) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard + # 50 == lowest CUDA 12 standard # 60 == f16 CUDA intrinsics # 61 == integer CUDA intrinsics # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster @@ -135,9 +135,9 @@ if (LLAMA_CUBLAS) message("CUDA Toolkit Version: ${CUDAToolkit_VERSION}") if(CUDAToolkit_VERSION VERSION_GREATER 12) add_compile_definitions(GGML_CUDA_USE_GRAPHS) #try enable cuda graphs on cu12 build - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics else() - set(CMAKE_CUDA_ARCHITECTURES "37;52;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics endif() endif() endif() diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 769bac17d..3ddd4ce5b 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -2811,9 +2811,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (!ctx->has_glm_projector) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + // The patches vector is used to get rows to index into the embeds with; + // we should skip dim 0 only if we have CLS to avoid going out of bounds + // when retrieving the rows. + int patch_offset = ctx->has_class_embedding ? 1 : 0; int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + patches_data[i] = i + patch_offset; } ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); free(patches_data); diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz index 3acd603ab..e6a22a4e3 100644 Binary files a/examples/server/public/index.html.gz and b/examples/server/public/index.html.gz differ diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 6f8ab2b93..4db4c783a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -7,6 +7,8 @@ // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +// disable Nagle's algorithm +#define CPPHTTPLIB_TCP_NODELAY true #include "httplib.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: diff --git a/examples/server/webui/src/components/ChatScreen.tsx b/examples/server/webui/src/components/ChatScreen.tsx index 80012c3e4..d7a246cf6 100644 --- a/examples/server/webui/src/components/ChatScreen.tsx +++ b/examples/server/webui/src/components/ChatScreen.tsx @@ -228,6 +228,7 @@ export default function ChatScreen() { value={inputMsg} onChange={(e) => setInputMsg(e.target.value)} onKeyDown={(e) => { + if (e.nativeEvent.isComposing || e.keyCode === 229) return; if (e.key === 'Enter' && e.shiftKey) return; if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); diff --git a/examples/server/webui/src/utils/llama-vscode.ts b/examples/server/webui/src/utils/llama-vscode.ts index 6c23221c4..76cc553a2 100644 --- a/examples/server/webui/src/utils/llama-vscode.ts +++ b/examples/server/webui/src/utils/llama-vscode.ts @@ -40,7 +40,7 @@ export const useVSCodeContext = ( window.addEventListener('message', handleMessage); return () => window.removeEventListener('message', handleMessage); - }, []); + }, [inputRef, setInputMsg]); // Add a keydown listener that sends the "escapePressed" message to the parent window useEffect(() => { diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index d23c6b262..9b8a69754 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -95,6 +95,7 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); GGML_BACKEND_API int ggml_cpu_has_sve (void); GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + GGML_BACKEND_API int ggml_cpu_has_sme (void); // other GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index ffef2130b..f0ac76861 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -5113,7 +5113,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const int nb = n / QK_K; -#ifdef __ARM_NEON +#if defined(__ARM_FEATURE_SVE) + + uint32_t utmp[4]; + + const int8_t m32 = 32; + const int vector_length = svcntb()*8; + const svuint8_t m3b_sv = svdup_n_u8(0x3); + const svint32_t vzero_sv = svdup_n_s32(0); + + const svuint8_t m0_sv = svdup_n_u8(1); + const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); + const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); + const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); + svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4)); + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3_sv = x[i].qs; + const uint8_t * restrict qh_sv = x[i].hmask; + const int8_t * restrict q8_sv = y[i].qs; + + // Set up scales + uint32_t * aux = &x[i].scales; + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + switch (vector_length) { + case 128: + { + svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv); + svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + if (j == 0) { + qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4); + qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_b32(), sumi1_1)); + } break; + case 256: + case 512: + { + svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + + svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + if (j == 0) { + qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sum; + +#elif __ARM_NEON uint32_t aux[3]; uint32_t utmp[4]; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 179b756dd..e4ba0c12f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -117,7 +117,8 @@ struct ggml_arm_arch_features_type { int has_i8mm; int has_sve; int sve_cnt; -} ggml_arm_arch_features = {-1, -1, -1, -1, 0}; + int has_sme; +} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1}; #endif @@ -2387,15 +2388,20 @@ bool ggml_is_numa(void) { #define HWCAP2_I8MM (1 << 13) #endif +#if !defined(HWCAP2_SME) +#define HWCAP2_SME (1 << 23) +#endif + static void ggml_init_arm_arch_features(void) { #if defined(__linux__) && defined(__aarch64__) uint32_t hwcap = getauxval(AT_HWCAP); uint32_t hwcap2 = getauxval(AT_HWCAP2); - ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); - ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); - ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME); #if defined(__ARM_FEATURE_SVE) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); @@ -2418,6 +2424,11 @@ static void ggml_init_arm_arch_features(void) { } ggml_arm_arch_features.has_i8mm = oldp; + if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_sme = oldp; + ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #else @@ -2441,6 +2452,12 @@ static void ggml_init_arm_arch_features(void) { ggml_arm_arch_features.has_sve = 0; ggml_arm_arch_features.sve_cnt = 0; #endif + +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2) + ggml_arm_arch_features.has_sme = 1; +#else + ggml_arm_arch_features.has_sme = 0; +#endif #endif } #endif @@ -14487,6 +14504,14 @@ int ggml_cpu_get_sve_cnt(void) { #endif } +int ggml_cpu_has_sme(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME) + return ggml_arm_arch_features.has_sme; +#else + return 0; +#endif +} + void ggml_cpu_init(void) { // needed to initialize f16 tables { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index fcae9387c..a198fa7f5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -14,6 +14,10 @@ #include "ggml-cpu-hbm.h" #endif +#ifdef GGML_USE_CPU_KLEIDIAI +#include "kleidiai/kleidiai.h" +#endif + #if defined(__APPLE__) #include #include @@ -39,6 +43,12 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type // } // #endif +#ifdef GGML_USE_CPU_KLEIDIAI + if (ggml_backend_cpu_kleidiai_buffer_type()) { + bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_AARCH64 if (ggml_backend_cpu_aarch64_buffer_type()) { bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); @@ -538,6 +548,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); } + if (ggml_cpu_has_sme()) { + features.push_back({ "SME", "1" }); + } if (ggml_cpu_has_riscv_v()) { features.push_back({ "RISCV_V", "1" }); } @@ -559,6 +572,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r #ifdef GGML_USE_OPENMP features.push_back({ "OPENMP", "1" }); #endif + #ifdef GGML_USE_CPU_KLEIDIAI + features.push_back({ "KLEIDIAI", "1" }); + #endif #ifdef GGML_USE_CPU_AARCH64 features.push_back({ "AARCH64_REPACK", "1" }); #endif diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp new file mode 100644 index 000000000..a8a59a887 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +// KleidiAI micro-kernels +#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_common.h" + +#include "kernels.h" + +#define NELEMS(x) sizeof(x) / sizeof(*x) +static ggml_kleidiai_kernels gemm_gemv_kernels[] = { +#if defined(__ARM_FEATURE_SME) + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .require_aligned_m_idx = */ true, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + }, +#endif +#if defined(__APPLE__) +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#else +#if defined(__ARM_FEATURE_MATMUL_INT8) + { + /* i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + }, + /* i8mm GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + }, +#endif +#if defined(__ARM_FEATURE_DOTPROD) + { + /* DOTPROD GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + }, + /* DOTPROD GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, + /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .require_aligned_m_idx = */ false, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + }, + /* .required_cpu = */ CPU_FEATURE_DOTPROD, + }, +#endif +#endif +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { + ggml_kleidiai_kernels * kernels = nullptr; + + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { + kernels = &gemm_gemv_kernels[i]; + break; + } + } + + return kernels; +} diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h new file mode 100644 index 000000000..a0b0d1493 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +enum cpu_feature { + CPU_FEATURE_NONE = 0, + CPU_FEATURE_DOTPROD = 1, + CPU_FEATURE_I8MM = 2, + CPU_FEATURE_SVE = 4, + CPU_FEATURE_SME = 8 +}; +inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { + lhs = static_cast(lhs | rhs); + return lhs; +} +inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) { + return static_cast(static_cast(lhs) | static_cast(rhs)); +} + +struct kernel_info { + size_t (*get_m_step)(void); + size_t (*get_n_step)(void); + size_t (*get_mr)(void); + size_t (*get_nr)(void); + size_t (*get_kr)(void); + size_t (*get_sr)(void); + size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); + size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); + size_t (*get_dst_size)(size_t m, size_t n); + void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, + float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); +}; + +struct lhs_packing_info { + size_t (*get_offset)(size_t m_idx, size_t lhs_stride); + size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, + size_t lhs_stride, void* lhs_packed); + bool require_aligned_m_idx; +}; + +struct rhs_packing_info { + size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, + const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); +}; + +struct ggml_kleidiai_kernels { + kernel_info gemm; + kernel_info gemv; + lhs_packing_info lhs_info; + rhs_packing_info rhs_info; + + cpu_feature required_cpu; +}; + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp new file mode 100644 index 000000000..66685fd16 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -0,0 +1,287 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#include +#elif defined(__APPLE__) +#include +#include +#include +#elif defined(_WIN32) +#include +#include +#endif + +#include "kleidiai.h" + +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml-threading.h" +#include "ggml-cpu-traits.h" + +#include "kernels.h" + +#include "kai_common.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + +struct ggml_kleidiai_context { + ggml_kleidiai_kernels * kernels; +} static ctx = { NULL }; + +static void init_kleidiai_context(void) { + + ggml_critical_section_start(); + static bool initialized = false; + + if (!initialized) { + initialized = true; + const char *env_var = getenv("GGML_KLEIDIAI_SME"); + int sme_enabled = 0; + + cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + + if (env_var) { + sme_enabled = atoi(env_var); + } + + if (sme_enabled != 0) { + features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + } + ctx.kernels = ggml_kleidiai_select_kernels(features); + } + ggml_critical_section_end(); +} + +static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + return tensor->ne[dim]; +} + +namespace ggml::cpu::kleidiai { +class tensor_traits : public ggml::cpu::tensor_traits { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + + size_t k = op->src[0]->ne[0]; + size_t m = op->src[1]->ne[1]; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); + + return true; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { + if (dst->op == GGML_OP_MUL_MAT) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ctx.kernels); + kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + // Calculate number of columns to be processed per thread + const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true; + const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } + + if(m_start < m) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); + void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + + lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr); + } + + ggml_barrier(params->threadpool); + + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + return true; + } + return false; + } + +public: + int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { + GGML_ASSERT(ctx.kernels); + const size_t n = tensor->ne[1]; + const size_t k = tensor->ne[0]; + size_t nr = ctx.kernels->gemm.get_nr(); + size_t kr = ctx.kernels->gemm.get_kr(); + size_t sr = ctx.kernels->gemm.get_sr(); + +#ifndef NDEBUG + const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); + GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); +#endif + struct kai_rhs_pack_qs4cxs1s0_param params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + + return 0; + + GGML_UNUSED(data_size); + } +}; + +static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { + static tensor_traits traits; + return &traits; +} +} // namespace ggml::cpu::kleidiai + +static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra; + auto OK = tensor_traits->repack(tensor, data, size); + + GGML_ASSERT(OK == 0); + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_KLEIDIAI"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::kleidiai { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + if ( op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32 && + ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + return true; + } + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::kleidiai + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { + static ggml::cpu::kleidiai::extra_buffer_type ctx; + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ &ctx, + }; + + init_kleidiai_context(); + + return &ggml_backend_cpu_buffer_type_kleidiai; +} diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h new file mode 100644 index 000000000..38eac58f7 --- /dev/null +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4a92d35f9..7e99838c0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -411,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) return __dp4a(a, b, c); -#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) const int8_t * a8 = (const int8_t *) &a; const int8_t * b8 = (const int8_t *) &b; return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 54c0f66d2..cca2bee0b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,5 @@ #include "cpy.cuh" +#include "dequantize.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { - const block_q8_0 * xi = (const block_q8_0 *) cxi; - float * dsti = (float *) cdsti; + float * cdstf = (float *)(cdsti); - const float d = (float)xi->d; - - for (int j = 0; j < QK8_0; j++) { - dsti[j] = xi->qs[j] * d; +#pragma unroll + for (int j = 0; j < QK8_0; j += 2) { + dfloat2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + 1) = dq.y; } } @@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { memcpy(dsti->qh, &qh, sizeof(qh)); } +template +static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); + +#pragma unroll + for (int j = 0; j < qk/2; j++) { + dfloat2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + qk/2) = dq.y; + } +} static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { if (x <= val[0]) return 0; @@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q4_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK4_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q4_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q4_1_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK4_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q5_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q5_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK5_0><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q5_1_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q5_1_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, + const int nb10, const int nb11, const int nb12, const int nb13, + cudaStream_t stream) { + const int num_blocks = ne; + cpy_q_f32, QK5_1><<>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_iq4_nl_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { @@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_q_f32; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK4_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK4_1>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK5_0>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32, QK5_1>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 27599a2b0..0ce4afbb2 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,13 +123,13 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); shared_memory_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -175,13 +175,13 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); shared_memory_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b99ff1f71..65e092486 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -264,6 +264,12 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no", prop.warpSize); +#elif defined(GGML_USE_MUSA) + // TODO: refine the .cc to reflect MUSA's actual CC capabilities + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; @@ -1783,9 +1789,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else -#ifdef GGML_USE_MUSA - GGML_ASSERT(false); -#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1828,7 +1831,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { @@ -3078,15 +3080,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { return true; } + if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { return true; } + if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { return true; } + if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { return true; } + if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index eab017889..1fbcbd045 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -16,7 +16,7 @@ #include #endif // __ARM_FEATURE_SVE -#if defined(__ARM_NEON) && !defined(__CUDACC__) +#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__) // if YCM cannot find , make a symbolic link to it, for example: // // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7d26a70f2..236b2b8dc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1439,6 +1439,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); } + // skip unused tensors + if (info.op == GGML_OP_NONE) { + LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str()); + ml.n_created++; + + return nullptr; + } + // tensors with "bias" suffix are always used with GGML_OP_ADD ggml_op op; bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;