diff --git a/common/arg.cpp b/common/arg.cpp index d52f8bff4..f2c2255c4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1532,6 +1532,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.ctx_shift = false; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT")); + add_opt(common_arg( + {"--context-shift"}, + string_format("enables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), + [](common_params & params) { + params.ctx_shift = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_CONTEXT_SHIFT")); add_opt(common_arg( {"--chunks"}, "N", string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), @@ -1825,7 +1832,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.sampling.top_n_sigma = std::stof(value); } - ).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam()); + ).set_sparam()); add_opt(common_arg( {"--xtc-probability"}, "N", string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), diff --git a/common/chat.cpp b/common/chat.cpp index e0aea958d..cb2f40523 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -632,7 +632,6 @@ const char * common_reasoning_format_name(common_reasoning_format format) { case COMMON_REASONING_FORMAT_AUTO: return "auto"; case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; - case COMMON_REASONING_FORMAT_GRANITE: return "granite"; default: throw std::runtime_error("Unknown reasoning format"); } diff --git a/common/common.h b/common/common.h index 14d92af0e..5b0502ac2 100644 --- a/common/common.h +++ b/common/common.h @@ -235,12 +235,15 @@ struct common_params_diffusion { bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0 }; +// reasoning API response format (not to be confused as chat template's reasoning format) enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, - COMMON_REASONING_FORMAT_AUTO, + COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content` COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. - COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. + // do not extend this enum unless you absolutely have to + // in most cases, use COMMON_REASONING_FORMAT_AUTO + // see: https://github.com/ggml-org/llama.cpp/pull/15408 }; @@ -368,7 +371,7 @@ struct common_params { bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention bool no_perf = false; // disable performance metrics - bool ctx_shift = true; // context shift on inifinite text generation + bool ctx_shift = false; // context shift on inifinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index f47612799..0bfb92df1 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -73,7 +73,6 @@ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index 49aae7a23..d3dfd049e 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char vshift4 = vec_splats((unsigned char)4); + vector float vsumf0 = vec_splats(0.0f); + + vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) * + GGML_E8M0_TO_FP32_HALF(x[ib].e)); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs); + + vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4); + + vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles); + vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + sumf = vec_extract(vsumf0, 0); + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu index 8bed62ac9..8d9cf692b 100644 --- a/ggml/src/ggml-cuda/add-id.cu +++ b/ggml/src/ggml-cuda/add-id.cu @@ -11,14 +11,14 @@ static __global__ void add_id_kernel( const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; - const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21); const size_t nb1 = ne0 * sizeof(float); const size_t nb2 = ne1 * nb1; float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); - const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); - const float * src1_row = (const float *)((char *)src1 + i11*nb11); + const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((const char *)src1 + i11*nb11); for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { dst_row[i0] = src0_row[i0] + src1_row[i0]; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d4409aff1..2b77183dd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -78,6 +78,8 @@ #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads +#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons + #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 #define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD @@ -494,13 +496,14 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) } -#if CUDART_VERSION < CUDART_HMASK +#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \ + (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < CUDART_HMASK +#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIP) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 39731baae..1d7e0b037 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1237,10 +1237,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); - GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(mask_h2); GGML_UNUSED(sinks_f); + GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(stride_Q1); GGML_UNUSED(stride_Q2); + GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE @@ -1395,8 +1397,8 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 660849f81..12da30df2 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -299,17 +299,17 @@ static __global__ void flash_attn_tile_ext_f16( } } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index c58194937..1c1dc725d 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -38,17 +38,6 @@ static __global__ void flash_attn_tile_ext_f32( return; #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); - GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); - GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; return; } @@ -313,7 +302,7 @@ static __global__ void flash_attn_tile_ext_f32( } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 36295fe95..7f4454c14 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -349,8 +349,8 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d6d0bfb74..a06fba6cd 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -37,17 +37,6 @@ static __global__ void flash_attn_vec_ext_f32( // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); NO_DEVICE_CODE; return; } @@ -346,8 +335,8 @@ static __global__ void flash_attn_vec_ext_f32( } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index ea855a188..3449648ef 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -471,9 +471,9 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(sinks); GGML_UNUSED(KV_max); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 1437367e8..5c66fe5bb 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,6 @@ static void mul_mat_f_cuda( cudaStream_t stream) { typedef tile<16, 8, T> tile_A; typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; GGML_ASSERT(!ids && "mul_mat_id not implemented"); @@ -352,9 +351,6 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); GGML_ASSERT( nb0 == ts_dst); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index ea92347f0..f0011aa56 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2856,12 +2856,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( #else typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; -#endif +#endif // defined(AMD_MFMA_AVAILABLE) constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #pragma unroll diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 6bee20413..6bcae9e52 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -39,7 +39,7 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r } __syncthreads(); sum = 0.0f; - if (lane_id < (blockDim.x / WARP_SIZE)) { + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { sum = s_sum[lane_id]; } sum = warp_reduce_sum(sum); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 94f6405ca..727932123 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -566,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8 for (int i = 0; i < n; ++i) { L[i] += nmax; } - return sumlx / suml2; + return suml2 > 0.0f ? sumlx / suml2 : 0.0f; } for (int i = 0; i < n; ++i) { int l = nearest_int(iscale * x[i]); @@ -901,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint for (int i = 0; i < n; ++i) { max = MAX(max, x[i]); } - if (!max) { // all zero + if (max < GROUP_MAX_EPS) { // all zero for (int i = 0; i < n; ++i) { L[i] = 0; } return 0.f; } @@ -966,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint break; } } - return sumlx/suml2; + return suml2 > 0.0f ? sumlx / suml2 : 0.0f; } static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) { @@ -4266,7 +4266,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R sumw[j+1] = sumw[j] + weight[i]; } } - float best_score = -FLT_MIN, scale = max; + float best_score = -FLT_MAX, scale = max; int besti1 = -1, besti2 = -1, best_shift = 0; for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { @@ -4442,7 +4442,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = -FLT_MIN, scale = max; + float best_score = -FLT_MAX, scale = max; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 62eb478ba..98291717b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -357,6 +357,12 @@ enum vk_conv_shapes { CONV_SHAPE_COUNT, }; +enum dmmv_wg_sizes { + DMMV_WG_SIZE_SUBGROUP, + DMMV_WG_SIZE_LARGE, + DMMV_WG_SIZE_COUNT, +}; + static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); @@ -447,8 +453,8 @@ struct vk_device_struct { vk_pipeline pipeline_quantize_q8_1; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; @@ -2785,54 +2791,61 @@ static void ggml_vk_load_shaders(vk_device& device) { rm_stdq = 2; uint32_t rm_iq = 2 * rm_kq; - for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) { + uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_16 : (subgroup_size_16 * 4); + uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? device->subgroup_size : (device->subgroup_size * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + const bool s = device->subgroup_add && device->architecture != vk_device_architecture::AMD_GCN; + + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f32_f32_f32_len[s], arr_dmmv_f32_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f16_f32_f32_len[s], arr_dmmv_f16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_bf16_f32_f32_len[s], arr_dmmv_bf16_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_0_f32_f32_len[s], arr_dmmv_q4_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_1_f32_f32_len[s], arr_dmmv_q4_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_0_f32_f32_len[s], arr_dmmv_q5_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_1_f32_f32_len[s], arr_dmmv_q5_1_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q8_0_f32_f32_len[s], arr_dmmv_q8_0_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q2_k_f32_f32_len[s], arr_dmmv_q2_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q3_k_f32_f32_len[s], arr_dmmv_q3_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_k_f32_f32_len[s], arr_dmmv_q4_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_k_f32_f32_len[s], arr_dmmv_q5_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q6_k_f32_f32_len[s], arr_dmmv_q6_k_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_s_f32_f32_len[s], arr_dmmv_iq1_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_m_f32_f32_len[s], arr_dmmv_iq1_m_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xxs_f32_f32_len[s], arr_dmmv_iq2_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xs_f32_f32_len[s], arr_dmmv_iq2_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_s_f32_f32_len[s], arr_dmmv_iq2_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_xxs_f32_f32_len[s], arr_dmmv_iq3_xxs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_s_f32_f32_len[s], arr_dmmv_iq3_s_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_xs_f32_f32_len[s], arr_dmmv_iq4_xs_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_nl_f32_f32_len[s], arr_dmmv_iq4_nl_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_mxfp4_f32_f32_len[s], arr_dmmv_mxfp4_f32_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f32_f16_f32_len[s], arr_dmmv_f32_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_f16_f16_f32_len[s], arr_dmmv_f16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_bf16_f16_f32_len[s], arr_dmmv_bf16_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_0_f16_f32_len[s], arr_dmmv_q4_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_1_f16_f32_len[s], arr_dmmv_q4_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_0_f16_f32_len[s], arr_dmmv_q5_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_1_f16_f32_len[s], arr_dmmv_q5_1_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q8_0_f16_f32_len[s], arr_dmmv_q8_0_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q2_k_f16_f32_len[s], arr_dmmv_q2_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q3_k_f16_f32_len[s], arr_dmmv_q3_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q4_k_f16_f32_len[s], arr_dmmv_q4_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q5_k_f16_f32_len[s], arr_dmmv_q5_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_q6_k_f16_f32_len[s], arr_dmmv_q6_k_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_s_f16_f32_len[s], arr_dmmv_iq1_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq1_m_f16_f32_len[s], arr_dmmv_iq1_m_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xxs_f16_f32_len[s], arr_dmmv_iq2_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_xs_f16_f32_len[s], arr_dmmv_iq2_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq2_s_f16_f32_len[s], arr_dmmv_iq2_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_xxs_f16_f32_len[s], arr_dmmv_iq3_xxs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq3_s_f16_f32_len[s], arr_dmmv_iq3_s_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_xs_f16_f32_len[s], arr_dmmv_iq4_xs_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_iq4_nl_f16_f32_len[s], arr_dmmv_iq4_nl_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(w)+"_"+std::to_string(i+1), arr_dmmv_mxfp4_f16_f32_len[s], arr_dmmv_mxfp4_f16_f32_data[s], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true); + } } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -4415,7 +4428,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); @@ -4449,7 +4462,24 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * return nullptr; } - return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if (ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. + if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { @@ -5756,7 +5786,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 903753c7e..b93e9948f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -1,6 +1,10 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#endif #ifdef MUL_MAT_ID #define EXPERT_COUNT 8 @@ -90,7 +94,38 @@ layout (constant_id = 2) const uint NUM_COLS = 1; shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; -void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { +void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // subgroupAdd is probably faster on devices that support it, + // particularly when the workgroup has more than one subgroup +#if USE_SUBGROUP_ADD + // sum up partial sums within a subgroup + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + // Go through shared memory to sum partials across subgroups + if (gl_SubgroupInvocationID == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][gl_SubgroupID] = temp[j][n]; + } + } + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + temp[j][n] += tmpsh[j][n][s]; + } + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +#else // sum up partial sums and write back result [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -115,4 +150,5 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32 } } } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4882de3d8..8a12f69c3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -237,7 +237,8 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 - std::string opt_level = coopmat ? "" : "-O"; + // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 + std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O"; #ifdef _WIN32 std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; @@ -486,6 +487,9 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); // Dequant shaders @@ -803,6 +807,18 @@ void write_output_files() { fputs(data.c_str(), src); fputs(len.c_str(), src); } + + for (const std::string& btype : {"f16", "f32"}) { + for (const auto& tname : type_names) { + fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[2];\n", tname.c_str(), btype.c_str()); + fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[2];\n", tname.c_str(), btype.c_str()); + std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data};\n"; + std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len};\n"; + fputs(data.c_str(), src); + fputs(len.c_str(), src); + } + } + fclose(hdr); fclose(src); } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 155e6ee70..8cccc1a4a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -145,11 +145,6 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { - LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", - __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); - } - if (!hparams.vocab_only) { // GPU backends for (auto * dev : model.devices) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8a8d2fff6..741eb8d31 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -91,6 +91,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_40B: return "40B"; case LLM_TYPE_65B: return "65B"; case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_120B: return "120B"; case LLM_TYPE_142B: return "142B"; case LLM_TYPE_236B: return "236B"; case LLM_TYPE_290B: return "290B"; @@ -1839,7 +1840,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); - // TODO: switch (hparams.n_layer) + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_LFM2: { @@ -6843,9 +6848,9 @@ struct llm_build_falcon : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -7123,9 +7128,9 @@ struct llm_build_dbrx : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -7245,13 +7250,13 @@ struct llm_build_starcoder : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7467,13 +7472,15 @@ struct llm_build_bert : public llm_graph_context { cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } if (model.layers[il].attn_q_norm) { @@ -7481,6 +7488,10 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + } else { + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); } if (model.layers[il].attn_k_norm) { @@ -7488,11 +7499,11 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } else { + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } // RoPE if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { @@ -7637,9 +7648,9 @@ struct llm_build_neo_bert : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // RoPE Qcur = ggml_rope_ext( @@ -7746,13 +7757,13 @@ struct llm_build_bloom : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7870,7 +7881,7 @@ struct llm_build_mpt : public llm_graph_context { ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -7889,17 +7900,18 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_k_norm_b, LLM_NORM, il); cb(Kcur, "Kcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { - Qcur = ggml_cont(ctx0, Qcur); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - Kcur = ggml_cont(ctx0, Kcur); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); @@ -8151,9 +8163,9 @@ struct llm_build_qwen : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -9126,21 +9138,21 @@ struct llm_build_phi2 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9264,21 +9276,21 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9528,17 +9540,17 @@ struct llm_build_gpt2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, @@ -9634,9 +9646,9 @@ struct llm_build_codeshell : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -10964,8 +10976,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0 cb(all_coefs, "all_coefs", il); - all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] - all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup] + all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] @@ -12378,9 +12390,9 @@ struct llm_build_gptneox : public llm_graph_context { ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13513,17 +13525,17 @@ struct llm_build_jais : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_cont_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, @@ -13626,6 +13638,7 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13635,11 +13648,10 @@ struct llm_build_chatglm : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -13760,6 +13772,7 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -13769,11 +13782,10 @@ struct llm_build_glm4 : public llm_graph_context { } Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -16940,13 +16952,13 @@ private: ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + ggml_tensor * Vcur = ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + Vcur = ggml_cont_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -17013,15 +17025,13 @@ private: cb(zx, "mamba_in_proj", il); // {8192, 5, 1, 1} -> {8192, 1, 5, 1} zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); - zx = ggml_cont(ctx0, zx); - zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); cb(zx, "mamba_in_proj_out", il); // split into z and x // => {head_dim * n_heads, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); - x = ggml_cont(ctx0, x); - x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); // x = ggml_permute(ctx0, x, 0, 2, 1, 3); cb(x, "mamba_x_split", il); diff --git a/src/llama-model.h b/src/llama-model.h index 46f7d0480..f639fa139 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -79,6 +79,7 @@ enum llm_type { LLM_TYPE_40B, LLM_TYPE_65B, LLM_TYPE_70B, + LLM_TYPE_120B, LLM_TYPE_142B, LLM_TYPE_236B, LLM_TYPE_290B, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index c842c738d..e53c30ada 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -533,8 +533,8 @@ struct clip_graph { const int patches_per_image = n_patches_x; const int kernel_size = hparams.proj_scale_factor; - cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); - cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); + cur = ggml_transpose(ctx0, cur); + cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size); // doing a pool2d to reduce the number of output tokens cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); @@ -562,13 +562,13 @@ struct clip_graph { GGML_ASSERT(scale_factor != 0); cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, seq / (scale_factor * scale_factor), bsz); @@ -595,13 +595,13 @@ struct clip_graph { // unshuffle h cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // unshuffle w - cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]); // projection cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm @@ -740,15 +740,15 @@ struct clip_graph { auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_add(ctx0, inp, inp_1); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_reshape_4d( + inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_cont_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); inp = ggml_reshape_4d( ctx0, inp, n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); - inp = ggml_reshape_3d( + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_3d( ctx0, inp, n_embd, n_patches_x * n_patches_y, batch_size); } @@ -1013,14 +1013,14 @@ struct clip_graph { GGML_ASSERT(scale_factor > 0); cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, cur->ne[1] * cur->ne[2]); } @@ -1106,14 +1106,14 @@ struct clip_graph { n_patches_y, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_4d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches_x / scale_factor, n_patches_y / scale_factor, bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + cur = ggml_cont_2d(ctx0, cur, n_embd * scale_factor * scale_factor, n_patches / scale_factor / scale_factor); cb(cur, "pixel_shuffle", -1); @@ -1346,8 +1346,8 @@ struct clip_graph { ggml_tensor * block_1 = nullptr; { // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] - mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3)); - mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); + mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3); + mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); // stride = 1, padding = 1, bias is nullptr block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); @@ -1452,9 +1452,9 @@ struct clip_graph { mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); // mlp_2 ne = [2048, 576, 1, 1] // // AVG Pool Layer 2*2, strides = 2 - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3); // mlp_2 ne = [576, 2048, 1, 1] - mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); // mlp_2 ne [24, 24, 2048, 1] mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); // weight ne = [3, 3, 2048, 1] @@ -1474,8 +1474,8 @@ struct clip_graph { // glm projector else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) { size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); - embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_permute(ctx0,embeddings,1,0,2,3); + embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); @@ -2030,7 +2030,6 @@ private: ggml_row_size(cur->type, n_dim), ggml_row_size(cur->type, n_dim*n_head), n_dim/2 * ggml_element_size(cur)); - second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors second = ggml_rope_ext( ctx0, second, @@ -3825,8 +3824,9 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->model.hparams; - // only for models using fixed size square images - int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); + // for models with fixed size image, the input image is already pre-processed and resized to square + int patch_size = params.patch_size; + int n_patches = (img->nx / patch_size) * (img->ny / patch_size); projector_type proj = ctx->proj_type(); @@ -3840,27 +3840,27 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: { - n_patches_sq /= 4; + n_patches /= 4; if (ctx->model.mm_glm_tok_boi) { - n_patches_sq += 2; // for BOI and EOI token embeddings + n_patches += 2; // for BOI and EOI token embeddings } } break; case PROJECTOR_TYPE_MINICPMV: { // Use actual config value if available, otherwise fall back to hardcoded values if (params.minicpmv_query_num > 0) { - n_patches_sq = params.minicpmv_query_num; + n_patches = params.minicpmv_query_num; } else { // Fallback to hardcoded values for legacy models if (params.minicpmv_version == 2) { - n_patches_sq = 96; + n_patches = 96; } else if (params.minicpmv_version == 3) { - n_patches_sq = 64; + n_patches = 64; } else if (params.minicpmv_version == 4) { - n_patches_sq = 64; + n_patches = 64; } else if (params.minicpmv_version == 5) { // MiniCPM-V 4.0 - n_patches_sq = 64; + n_patches = 64; } else { GGML_ABORT("Unknown minicpmv version"); } @@ -3869,67 +3869,56 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: { - // dynamic size + // dynamic size (2 conv, so double patch size) int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); - n_patches_sq = x_patch * y_patch; + n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_GEMMA3: - { - int n_per_side = params.image_size / params.patch_size; - int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; - n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool; - } break; case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_LLAMA4: + case PROJECTOR_TYPE_LFM2: { // both W and H are divided by proj_scale_factor - n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor); + int scale_factor = ctx->model.hparams.proj_scale_factor; + n_patches /= (scale_factor * scale_factor); } break; case PROJECTOR_TYPE_PIXTRAL: { // dynamic size int n_merge = params.spatial_merge_size; - int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); - int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); - n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row - } break; - case PROJECTOR_TYPE_LLAMA4: - { - int scale_factor = ctx->model.hparams.proj_scale_factor; - n_patches_sq /= (scale_factor * scale_factor); + int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1); + int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1); + n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: { - n_patches_sq = img->nx; + n_patches = img->nx; const int proj_stack_factor = ctx->model.hparams.proj_stack_factor; if (ctx->model.audio_has_stack_frames()) { GGML_ASSERT(proj_stack_factor > 0); - const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor); - n_patches_sq = n_len / proj_stack_factor; + const int n_len = CLIP_ALIGN(n_patches, proj_stack_factor); + n_patches = n_len / proj_stack_factor; } // whisper downscales input token by half after conv1d - n_patches_sq /= 2; + n_patches /= 2; if (ctx->model.audio_has_avgpool()) { // divide by 2 because of nn.AvgPool1d(2, stride=2) - n_patches_sq /= 2; + n_patches /= 2; } } break; - case PROJECTOR_TYPE_LFM2: - { - n_patches_sq = (img->nx / (params.patch_size * params.proj_scale_factor)) * (img->ny / (params.patch_size * params.proj_scale_factor)); - } break; default: GGML_ABORT("unsupported projector type"); } - return n_patches_sq; + return n_patches; } static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index e73cf96af..6f8a5f86a 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -68,6 +68,7 @@ add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0" add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" +add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0b40f7bfa..ab88f3d26 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1201,6 +1201,8 @@ struct server_task_result_metrics : server_task_result { uint64_t n_tokens_predicted_total = 0; uint64_t t_tokens_generation_total = 0; + uint64_t n_past_max = 0; + uint64_t n_prompt_tokens_processed = 0; uint64_t t_prompt_processing = 0; @@ -1226,6 +1228,8 @@ struct server_task_result_metrics : server_task_result { { "n_tokens_predicted_total", n_tokens_predicted_total }, { "t_prompt_processing_total", t_prompt_processing_total }, + { "n_past_max", n_past_max }, + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, { "t_prompt_processing", t_prompt_processing }, { "n_tokens_predicted", n_tokens_predicted }, @@ -1587,6 +1591,8 @@ struct server_metrics { uint64_t n_tokens_predicted_total = 0; uint64_t t_tokens_generation_total = 0; + uint64_t n_past_max = 0; + uint64_t n_prompt_tokens_processed = 0; uint64_t t_prompt_processing = 0; @@ -1605,6 +1611,10 @@ struct server_metrics { n_prompt_tokens_processed += slot.n_prompt_tokens_processed; t_prompt_processing += slot.t_prompt_processing; t_prompt_processing_total += slot.t_prompt_processing; + + if (slot.n_past > 0) { + n_past_max = std::max(n_past_max, (uint64_t) slot.n_past); + } } void on_prediction(const server_slot & slot) { @@ -1620,6 +1630,9 @@ struct server_metrics { if (slot.is_processing()) { n_busy_slots_total++; } + if (slot.n_past > 0) { + n_past_max = std::max(n_past_max, (uint64_t) slot.n_past); + } } } @@ -1716,7 +1729,7 @@ struct server_queue { void pop_deferred_task() { std::unique_lock lock(mutex_tasks); if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_back(std::move(queue_tasks_deferred.front())); + queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); queue_tasks_deferred.pop_front(); } condition_tasks.notify_one(); @@ -2875,6 +2888,8 @@ struct server_context { res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; res->t_tokens_generation_total = metrics.t_tokens_generation_total; + res->n_past_max = metrics.n_past_max; + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; res->t_prompt_processing = metrics.t_prompt_processing; res->n_tokens_predicted = metrics.n_tokens_predicted; @@ -4077,6 +4092,10 @@ int main(int argc, char ** argv) { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, {"value", res_metrics->n_decode_total} + }, { + {"name", "n_past_max"}, + {"help", "Largest observed n_past."}, + {"value", res_metrics->n_past_max} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 1485de8ce..c7b3af048 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,7 +5,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index be3a0052c..adb6f2786 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -7,7 +7,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @@ -229,7 +229,7 @@ def test_nocache_long_input_prompt(): "temperature": 1.0, "cache_prompt": False, }) - assert res.status_code == 200 + assert res.status_code == 400 def test_completion_with_tokens_input(): diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 2431ac708..8f51bc301 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -11,7 +11,7 @@ Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. """.strip() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() @@ -25,6 +25,7 @@ def test_ctx_shift_enabled(): # the prompt is truncated to keep the last 109 tokens # 64 tokens are generated thanks to shifting the context when it gets full global server + server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -42,7 +43,6 @@ def test_ctx_shift_enabled(): ]) def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): global server - server.disable_ctx_shift = True server.n_predict = -1 server.start() res = server.make_request("POST", "/completion", data={ @@ -56,7 +56,6 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr def test_ctx_shift_disabled_long_prompt(): global server - server.disable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 64, @@ -68,7 +67,6 @@ def test_ctx_shift_disabled_long_prompt(): def test_ctx_shift_disabled_stream(): global server - server.disable_ctx_shift = True server.start() res = server.make_stream_request("POST", "/v1/completions", data={ "n_predict": 256, diff --git a/tools/server/tests/unit/test_embedding.py b/tools/server/tests/unit/test_embedding.py index 0feb452cc..50601b839 100644 --- a/tools/server/tests/unit/test_embedding.py +++ b/tools/server/tests/unit/test_embedding.py @@ -8,7 +8,7 @@ server = ServerPreset.bert_bge_small() EPSILON = 1e-3 -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.bert_bge_small() diff --git a/tools/server/tests/unit/test_infill.py b/tools/server/tests/unit/test_infill.py index 10554db0f..73dacdae8 100644 --- a/tools/server/tests/unit/test_infill.py +++ b/tools/server/tests/unit/test_infill.py @@ -3,7 +3,7 @@ from utils import * server = ServerPreset.tinyllama_infill() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama_infill() diff --git a/tools/server/tests/unit/test_lora.py b/tools/server/tests/unit/test_lora.py index c1aa8be70..00b2f245f 100644 --- a/tools/server/tests/unit/test_lora.py +++ b/tools/server/tests/unit/test_lora.py @@ -5,7 +5,7 @@ server = ServerPreset.stories15m_moe() LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.stories15m_moe() diff --git a/tools/server/tests/unit/test_rerank.py b/tools/server/tests/unit/test_rerank.py index f4f570ad5..0b63c7821 100644 --- a/tools/server/tests/unit/test_rerank.py +++ b/tools/server/tests/unit/test_rerank.py @@ -4,7 +4,7 @@ from utils import * server = ServerPreset.jina_reranker_tiny() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.jina_reranker_tiny() diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 620b25376..0e1158055 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -6,7 +6,7 @@ server = ServerPreset.tinyllama2() TEST_API_KEY = "sk-this-is-the-secret-key" -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_slot_save.py b/tools/server/tests/unit/test_slot_save.py index 38704f5ec..1b428cc2a 100644 --- a/tools/server/tests/unit/test_slot_save.py +++ b/tools/server/tests/unit/test_slot_save.py @@ -3,7 +3,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_speculative.py b/tools/server/tests/unit/test_speculative.py index 54db38cf3..38ca4325b 100644 --- a/tools/server/tests/unit/test_speculative.py +++ b/tools/server/tests/unit/test_speculative.py @@ -16,7 +16,7 @@ def create_server(): server.draft_max = 8 -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def fixture_create_server(): return create_server() @@ -91,6 +91,7 @@ def test_slot_ctx_not_exceeded(): def test_with_ctx_shift(): global server server.n_ctx = 64 + server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "prompt": "Hello " * 56, diff --git a/tools/server/tests/unit/test_tokenize.py b/tools/server/tests/unit/test_tokenize.py index 382457c9d..424cac5f3 100644 --- a/tools/server/tests/unit/test_tokenize.py +++ b/tools/server/tests/unit/test_tokenize.py @@ -4,7 +4,7 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinyllama2() diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index 20f048c6f..a3c3ccdf5 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -22,6 +22,8 @@ def create_server(): server.model_alias = "tinyllama-2-tool-call" server.server_port = 8081 server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 class CompletionMode(Enum): NORMAL = "normal" diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index bc547ca03..49277e600 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -79,7 +79,7 @@ class ServerProcess: draft: int | None = None api_key: str | None = None lora_files: List[str] | None = None - disable_ctx_shift: int | None = False + enable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None @@ -178,8 +178,8 @@ class ServerProcess: if self.lora_files: for lora_file in self.lora_files: server_args.extend(["--lora", lora_file]) - if self.disable_ctx_shift: - server_args.extend(["--no-context-shift"]) + if self.enable_ctx_shift: + server_args.append("--context-shift") if self.api_key: server_args.extend(["--api-key", self.api_key]) if self.draft_max: diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp index a71e9bf5b..18f01a994 100644 --- a/tools/tts/tts.cpp +++ b/tools/tts/tts.cpp @@ -581,7 +581,6 @@ int main(int argc, char ** argv) { params.model = params.vocoder.model; params.embedding = true; - params.ctx_shift = false; // silence warning params.n_ubatch = params.n_batch; common_init_result llama_init_cts = common_init_from_params(params);