diff --git a/common/arg.cpp b/common/arg.cpp index 0614f305c..ba184c700 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2871,6 +2871,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "(default: deepseek)", [](common_params & params, const std::string & value) { /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } + else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } else { throw std::invalid_argument("invalid value"); } } diff --git a/common/chat.cpp b/common/chat.cpp index 9dfa640b2..846d7ab45 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -82,10 +82,10 @@ json common_chat_msg::to_json_oaicompat() const std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { std::vector diffs; - // if (previous_msg.reasoning_content != current.reasoning_content) { - // auto & diff = diffs.emplace_back(); - // diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content); - // } + if (previous_msg.reasoning_content != new_msg.reasoning_content) { + auto & diff = diffs.emplace_back(); + diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); + } if (previous_msg.content != new_msg.content) { auto & diff = diffs.emplace_back(); diff.content_delta = string_diff(previous_msg.content, new_msg.content); @@ -385,9 +385,9 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { json delta = json::object(); - // if (!diff.reasoning_content_delta.empty()) { - // delta["reasoning_content"] = msg.reasoning_content; - // } + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } if (!diff.content_delta.empty()) { delta["content"] = diff.content_delta; } @@ -598,6 +598,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) { switch (format) { case COMMON_REASONING_FORMAT_NONE: return "none"; case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; default: throw std::runtime_error("Unknown reasoning format"); } diff --git a/common/chat.h b/common/chat.h index f6b1d0ffc..9f59e6b08 100644 --- a/common/chat.h +++ b/common/chat.h @@ -70,7 +70,7 @@ struct common_chat_msg { }; struct common_chat_msg_diff { - // std::string reasoning_content_delta; + std::string reasoning_content_delta; std::string content_delta; size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; diff --git a/common/common.h b/common/common.h index aad27f500..c1795831a 100644 --- a/common/common.h +++ b/common/common.h @@ -211,7 +211,8 @@ struct common_params_vocoder { enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, - COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` + COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. }; struct common_params { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d8de7531b..08facb6d0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8132,8 +8132,8 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define WKV_VECTOR_SIZE 4 #endif - int wkv_vector_size; #ifdef WKV_VECTOR_SIZE + int wkv_vector_size; #if defined(__ARM_FEATURE_SVE) wkv_vector_size = svcntw(); #else @@ -8348,8 +8348,8 @@ static void ggml_compute_forward_gla_f32( #define GLA_VECTOR_SIZE 4 #endif - int gla_vector_size; #ifdef GLA_VECTOR_SIZE + int gla_vector_size; #if defined(__ARM_FEATURE_SVE) gla_vector_size = svcntw(); #else diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 91440fefd..22f9b9c5b 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + const float KQ_max_diff = KQ_max[col] - KQ_max_new[col]; + KQ_max_scale[col] = expf(KQ_max_diff); KQ_max[col] = KQ_max_new[col]; + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c100787e2..7dbcdb060 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; + // 2*(2*ncpsg + nqptg)*(nsg) // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) // @@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node( // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; @@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node( // and store the soft_max values and the mask // // ne00*(nsg) - // each simdgroup has a full f16 head vector in shared mem to accumulate results + // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59899550e..58763e39e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext( constexpr short NW = N_SIMDWIDTH; constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = DK + 2*TS; // shared memory size per query in (half) + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = 2*DK + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t @@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext( if (iq1 + j < args.ne01) { sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*DK4 + i] = (q4_t) 0.0f; + sq4[j*DK4 + i] = 0; } } } @@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - float S = { 0.0f }; - float M = { -__FLT_MAX__/2 }; - threadgroup_barrier(mem_flags::mem_threadgroup); // each simdgroup stores its output to shared memory, reusing sq @@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext( const float M0 = ss[j*TS + 1]; const float M1 = ss[j*TS + sg*SH + 1]; - M = max(M0, M1); + const float M = max(M0, M1); const float ms0 = exp(M0 - M); const float ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[j*TS + 0] = S; @@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext( } } - device float4 * dst4 = (device float4 *) dst; + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK); // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { - const float S = ss[j*TS + 0]; + for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { + const float S = 1.0f/sf[j*TS + 0]; - for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; - } + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*DV4 + i]*S; } } } @@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext( // template to be able to explore different combinations // #define FA_TYPES \ - half, half4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 + float, float4, simdgroup_float8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 + +#define FA_TYPES_BF \ + bfloat, bfloat4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; @@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES +#undef FA_TYPES_BF template< typename q4_t, // query types in shared memory @@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec( const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results // store the result for all queries in local memory (the O matrix from the paper) o4_t lo[DV4/NL]; @@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec( half4, \ float, \ float, float4, \ - half4 + float4 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl new file mode 100644 index 000000000..132758469 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -0,0 +1,109 @@ +kernel void kernel_concat_f32_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice + int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) + int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice + int dim +) { + global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); + global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); + global float * dst = (global float*)((global char*)p_dst + off_dst); + + int i0 = get_global_id(0); // Index along dst's 0th dimension + int i1 = get_global_id(1); // Index along dst's 1st dimension + int i2 = get_global_id(2); // Index along dst's 2nd dimension + + if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { + return; + } + + ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; + ulong src_idx; + + if (dim == 0) { + if (i0 < d_ne00) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 1) { + if (i1 < d_ne01) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 2) { + if (i2 < d_ne02) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + + src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } +} + +kernel void kernel_concat_f32_non_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + + long ne00, long ne01, long ne02, long ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + + ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 + + long d_ne0, long d_ne1, long d_ne2, long d_ne3, + ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, + int dim +) { + global const char * src0_base = p_src0 + off_src0; + global const char * src1_base = p_src1 + off_src1; + global char * dst_base = p_dst + off_dst; + + long current_i1 = get_global_id(0); // Index for dst_dim_1 + long current_i2 = get_global_id(1); // Index for dst_dim_2 + long current_i3 = get_global_id(2); // Index for dst_dim_3 + + if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { + return; + } + + global const float * x_val_ptr; + global float * y_val_ptr; + + for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { + bool use_src0; + long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; + + if (dim == 0) { + use_src0 = (current_i0 < ne00); + if (!use_src0) { s_i0 = current_i0 - ne00; } + } else if (dim == 1) { + use_src0 = (current_i1 < ne01); + if (!use_src0) { s_i1 = current_i1 - ne01; } + } else if (dim == 2) { + use_src0 = (current_i2 < ne02); + if (!use_src0) { s_i2 = current_i2 - ne02; } + } else { // dim == 3 + use_src0 = (current_i3 < ne03); + if (!use_src0) { s_i3 = current_i3 - ne03; } + } + + if (use_src0) { + x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + } else { + x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + } + + y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); + *y_val_ptr = *x_val_ptr; + } +} diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl new file mode 100644 index 000000000..747fa7feb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -0,0 +1,30 @@ +kernel void kernel_pad( + global const void * src0_ptr, + ulong src0_offset, + global void * dst_ptr, + ulong dst_offset, + int s_ne0, int s_ne1, int s_ne2, + int d_ne0, int d_ne1, int d_ne2 +) { + global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); + global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + + int nidx = get_global_id(0); + int idx_d1 = get_group_id(1); + int idx_d2 = get_group_id(2); + + if (nidx >= d_ne0) { + return; + } + + int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + + bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + + if (in_src_bounds) { + int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; + dst[dst_el_offset] = src0[src_el_offset]; + } else { + dst[dst_el_offset] = 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl new file mode 100644 index 000000000..079498f5a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -0,0 +1,39 @@ +kernel void kernel_repeat( + global const char * src0_data_in, + global char * dst_data_in, + ulong src0_offset, + ulong dst_offset, + int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, + ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, + int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, + ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +) { + global const char * src0_data = src0_data_in + src0_offset; + global char * dst_data = dst_data_in + dst_offset; + + const int d3 = get_global_id(2); + const int d2 = get_global_id(1); + const int d1 = get_global_id(0); + + if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { + return; + } + + const int s3 = d3 % src0_ne3; + const int s2 = d2 % src0_ne2; + const int s1 = d1 % src0_ne1; + + const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; + global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + + for (int d0 = 0; d0 < dst_ne0; ++d0) { + // Determine source index for dimension 0 based on tiling/broadcasting. + const int s0 = d0 % src0_ne0; + + const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; + global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; + for (int k = 0; k < src0_nb0; ++k) { + current_dst_el_ptr[k] = current_src_el_ptr[k]; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl new file mode 100644 index 000000000..d9da86b14 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -0,0 +1,63 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +kernel void kernel_tanh_f32_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} + +kernel void kernel_tanh_f16_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/tsembd.cl b/ggml/src/ggml-opencl/kernels/tsembd.cl new file mode 100644 index 000000000..4b1107f70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tsembd.cl @@ -0,0 +1,48 @@ +kernel void kernel_timestep_embedding( + global const void * p_timesteps, + ulong off_timesteps, + global void * p_dst, + ulong off_dst, + int dst_nb1_bytes, + int logical_dim, + int max_period +) { + int local_i; + int local_j; + int local_half_dim; + float local_timestep_val; + float local_freq; + float local_arg; + global float * local_embed_data_ptr; + global const float * local_timesteps_input_ptr; + global float * local_dst_output_base_ptr; + + local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps); + local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst); + + local_i = get_global_id(1); + local_j = get_global_id(0); + + local_half_dim = logical_dim / 2; + local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes); + + if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) { + local_embed_data_ptr[logical_dim] = 0.0f; + } + + if (local_j >= local_half_dim) { + return; + } + + local_timestep_val = local_timesteps_input_ptr[local_i]; + + if (local_half_dim == 0) { + local_freq = 1.0f; + } else { + local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim); + } + + local_arg = local_timestep_val * local_freq; + local_embed_data_ptr[local_j] = cos(local_arg); + local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg); +} diff --git a/ggml/src/ggml-opencl/kernels/upscale.cl b/ggml/src/ggml-opencl/kernels/upscale.cl new file mode 100644 index 000000000..219d31dbb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/upscale.cl @@ -0,0 +1,121 @@ +kernel void kernel_upscale( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10 * ne11 * ne12 * ne13; + + if (index >= dst_total_elements) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = index / (ne10 * ne11 * ne12); + + int i00 = (int)(i10 / sf0); + int i01 = (int)(i11 / sf1); + int i02 = (int)(i12 / sf2); + int i03 = (int)(i13 / sf3); + + ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00; + global const float * src_element_ptr = (global const float *)(src_base + offset_src_element); + + dst_base[index] = *src_element_ptr; +} + +kernel void kernel_upscale_bilinear( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne00_src, + int ne01_src, + int ne10_dst, + int ne11_dst, + int ne12_dst, + int ne13_dst, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + int i10_dst = index % ne10_dst; + int i11_dst = (index / ne10_dst) % ne11_dst; + int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + int i02_src = (int)(i12_dst / sf2); + int i03_src = (int)(i13_dst / sf3); + + const float pixel_offset = 0.5f; + + float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + long y0_src = (long)floor(y_src_f); + long y1_src = y0_src + 1; + + y0_src = max(0L, min(y0_src, (long)ne01_src - 1)); + y1_src = max(0L, min(y1_src, (long)ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + long x0_src = (long)floor(x_src_f); + long x1_src = x0_src + 1; + + x0_src = max(0L, min(x0_src, (long)ne00_src - 1)); + x1_src = max(0L, min(x1_src, (long)ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst_base[index] = result; +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 564fe7d37..633f0f3c3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -412,6 +412,7 @@ struct vk_device_struct { vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -460,7 +461,7 @@ struct vk_device_struct { // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; vk::QueryPool query_pool; - uint32_t num_queries; + int32_t num_queries; ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -722,6 +723,21 @@ struct vk_op_timestep_embedding_push_constants { uint32_t max_period; }; +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -2742,6 +2758,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -6416,6 +6434,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_timestep_embedding_f32; } return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; case GGML_OP_POOL_2D: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_pool2d_f32; @@ -6750,6 +6773,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co uint32_t half_ceil = (dim + 1) / 2; elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; case GGML_OP_POOL_2D: { const uint32_t N = dst->ne[3]; @@ -7553,6 +7580,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context }, dryrun); } +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -8624,6 +8682,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: @@ -8688,6 +8747,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: @@ -8859,6 +8919,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); @@ -8987,6 +9051,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: @@ -9537,8 +9602,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (ctx->device->query_pool) { ctx->device->device.destroyQueryPool(ctx->device->query_pool); } - VkQueryPoolCreateInfo query_create_info = { VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO }; - query_create_info.queryType = VK_QUERY_TYPE_TIMESTAMP; + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; query_create_info.queryCount = cgraph->n_nodes + 100; ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); ctx->device->num_queries = query_create_info.queryCount; @@ -9624,7 +9689,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); - ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); for (int i = 0; i < cgraph->n_nodes; i++) { if (!ggml_vk_is_empty(cgraph->nodes[i])) { ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); @@ -10048,6 +10113,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; default: return false; } @@ -10539,6 +10606,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 000000000..b17b4e83e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. +uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f169fef77..79a110593 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -636,6 +636,8 @@ void process_shaders() { string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index db9a56208..8a571d719 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const { return kv_self; } -bool llama_context::kv_self_update() { +void llama_context::kv_self_defrag_sched() { + if (!memory) { + return; + } + + memory_force_optimize = true; +} + +bool llama_context::kv_self_update(bool optimize) { if (!memory) { return false; } llama_kv_cache * kv_self = static_cast(memory.get()); - if (!kv_self->update(*this)) { - // no updates have been performed - return false; + { + // TODO: remove in the future + optimize |= memory_force_optimize; + memory_force_optimize = false; + + const auto kv_state = kv_self->init_update(this, optimize); + switch (kv_state->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + // noop + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + // no updates need to be performed + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); + return false; + } + } + + if (!kv_state->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); + } } // if the KV cache did any computation, we have to reserve a new worst-case graph const auto kv_state = kv_self->init_full(); if (!kv_state) { - throw std::runtime_error("failed to initialize KV cache"); + throw std::runtime_error("failed to initialize memory state"); } const uint32_t n_seqs = cparams.n_seq_max; @@ -452,7 +484,7 @@ bool llama_context::kv_self_update() { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); if (!gf) { - LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); + LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } return true; @@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } + bool did_optimize = false; + // handle any pending defrags/shifts - kv_self_update(); + kv_self_update(false); llama_memory_state_ptr kv_state; - bool did_defrag = false; - while (true) { kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); if (!kv_state) { @@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status()); + + return -2; + } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { - if (!did_defrag) { - did_defrag = true; + if (!did_optimize) { + did_optimize = true; - kv_self->defrag_sched(-1.0f); - if (kv_self_update()) { - LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens); + if (kv_self_update(true)) { + LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens); continue; } } - LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); + LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens); return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { + LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens); + return -2; } } @@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - // decide if we need to defrag the kv cache - if (cparams.defrag_thold > 0.0f) { - kv_self->defrag_sched(cparams.defrag_thold); - } - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. ggml_backend_sched_reset(sched.get()); @@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) { // deprecated void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(); + ctx->kv_self_update(false); } enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { @@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { // deprecated void llama_kv_self_defrag(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); - if (!kv) { - return; - } - // force defrag - kv->defrag_sched(-1.0f); + ctx->kv_self_defrag_sched(); } bool llama_kv_self_can_shift(const llama_context * ctx) { diff --git a/src/llama-context.h b/src/llama-context.h index 3b880286b..c1c7efb31 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -52,7 +52,8 @@ struct llama_context { // return true of the KV cache was updated // TODO: remove - bool kv_self_update(); + bool kv_self_update(bool optimize); + void kv_self_defrag_sched(); enum llama_pooling_type pooling_type() const; @@ -231,6 +232,9 @@ private: std::unique_ptr memory; + // TODO: temporary, until the llama_kv_self_defrag() API is removed + bool memory_force_optimize = false; + // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 7e4cefb00..50fe0c44d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -769,9 +769,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { - // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) - ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); - repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] + // repeat cur to [n_embd, n_expert_used, n_tokens] + ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); cur = ggml_mul(ctx0, repeated, weights); cb(cur, "ffn_moe_weighted", il); } diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index e6173eb6f..7cfb3ea05 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -1,6 +1,7 @@ #include "llama-kv-cache-recurrent.h" #include "llama-impl.h" +#include "llama-io.h" #include "llama-batch.h" #include "llama-model.h" @@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); } +llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(lctx); + GGML_UNUSED(optimize); + + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); +} + bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { // simply remember the full state because it is very small for this type of cache // TODO: optimize @@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector & ubatche return success; } -bool llama_kv_cache_recurrent::update(llama_context & lctx) { - GGML_UNUSED(lctx); - // noop - return false; -} - -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} - bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index a178ae85c..b32f258fb 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -52,9 +52,7 @@ public: llama_memory_state_ptr init_full() override; - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; bool prepare(const std::vector & ubatches); diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 0eb045634..3aa606c84 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch assert(heads_base.size() == heads_swa.size()); - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique( this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); + return std::make_unique(this); } -bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = false; - - res = res | kv_base->update(lctx); - res = res | kv_swa ->update(lctx); - - return res; -} - -void llama_kv_cache_unified_iswa::defrag_sched(float thold) { - kv_base->defrag_sched(thold); - kv_swa ->defrag_sched(thold); +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_kv_cache_unified_iswa::get_can_shift() const { @@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( - llama_memory_status status, - llama_kv_cache_unified_iswa * kv) : status(status) { - state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); - state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); + llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_full(); + state_swa = kv->get_swa ()->init_full(); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_update(lctx, optimize); + state_swa = kv->get_swa ()->init_update(lctx, optimize); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); } llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv, llama_sbatch sbatch, std::vector heads_base, std::vector heads_swa, std::vector ubatches) - : status(status), - sbatch(std::move(sbatch)), - ubatches(std::move(ubatches)) { - // note: here we copy the ubatches. not sure if this is ideal - state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); - state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); - } + : status(LLAMA_MEMORY_STATUS_SUCCESS), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; @@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; } const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return state_base.get(); + return static_cast(state_base.get()); } const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return state_swa.get(); + return static_cast(state_swa.get()); } diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 8b067da03..cba5bbe95 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -54,9 +54,7 @@ public: llama_memory_state_ptr init_full() override; - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -86,12 +84,16 @@ public: // used to create a full-cache state llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv); + // used to create an update state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize); + // used to create a state from a batch llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv, llama_sbatch sbatch, std::vector heads_base, @@ -120,7 +122,7 @@ public: const llama_kv_cache_unified_state * get_swa() const; private: - const llama_memory_status status; + llama_memory_status status; //llama_kv_cache_unified_iswa * kv; @@ -131,6 +133,6 @@ private: std::vector ubatches; - std::unique_ptr state_base; - std::unique_ptr state_swa; + llama_memory_state_ptr state_base; + llama_memory_state_ptr state_swa; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index a2ec8d222..d26c0e1d7 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -1,6 +1,7 @@ #include "llama-kv-cache-unified.h" #include "llama-impl.h" +#include "llama-io.h" #include "llama-model.h" #include "llama-context.h" @@ -149,12 +150,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + if (seq_id >= 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } else { + // match any sequence + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); - if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { if (new_head == cells.size()) { new_head = i; } @@ -305,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique( this, std::move(sbatch), std::move(heads), std::move(ubatches)); } llama_memory_state_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); + return std::make_unique(this); } -std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { - std::vector res; +llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { + bool do_shift = get_has_shift(); + + defrag_info dinfo; + + // see if we need to defrag + { + bool do_defrag = optimize; + + const auto thold = lctx->get_cparams().defrag_thold; + + if (!do_defrag && thold > 0.0f) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } + } + + if (do_defrag) { + dinfo = defrag_prepare(lctx->graph_max_nodes()); + } + } + + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); +} + +llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::ubatch_heads res; struct state { uint32_t head_old; // old position of the head, before placing the ubatch @@ -359,12 +408,12 @@ std::vector llama_kv_cache_unified::prepare(const std::vectorget_sched(); - if (cells.get_has_shift()) { + if (do_shift) { if (!get_can_shift()) { printf("\nWARNING: The current KV cache / model configuration does not support K-shift"); } else { @@ -375,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { ggml_backend_sched_reset(sched); - auto * gf = lctx.graph_init(); + auto * gf = lctx->graph_init(); - auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); return updated; @@ -390,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { res->set_inputs(nullptr); - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); return updated; } @@ -401,56 +450,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { cells.reset_shift(); }} - if (do_defrag) { + if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched); + // apply moves: + { + const auto n_kv = dinfo.ids.size(); - auto * gf = lctx.graph_init(); + for (uint32_t i = 0; i < n_kv; ++i) { + assert(dinfo.ids[i] <= n_kv); - auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); - return updated; + if (dinfo.ids[i] == n_kv) { + continue; + } + + cells.mv(i, dinfo.ids[i]); } - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); - return updated; - } - - updated = true; + // reset the head so we can find the first free slot during the next ubatch + head = 0; } - do_defrag = false; + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; } return updated; } -void llama_kv_cache_unified::defrag_sched(float thold) { - const auto n_kv = cells.used_max_p1(); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; @@ -597,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const { return cells.size(); } +bool llama_kv_cache_unified::get_has_shift() const { + return cells.get_has_shift(); +} + uint32_t llama_kv_cache_unified::get_n_kv() const { return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); } @@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( } llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const { auto res = std::make_unique(); - const auto & ids = defrag_info.ids; + const auto & ids = dinfo.ids; #if 0 // CPU defrag @@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( return res; } -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { +llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; + defrag_info res; + auto & ids = res.ids; - ids.clear(); ids.resize(n_kv, n_kv); for (uint32_t i0 = 0; i0 < n_used; ++i0) { @@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // this cell goes to (i0 + nf) ids[i1] = i0 + nf; - // move the cell meta data - cells.mv(i1, i0 + nf); - - head = n_used; - if (!cont) { n_moves++; cont = true; @@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { } if (n_moves == 0) { - return false; + return {}; } LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - return true; + return res; } bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { @@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv) : status(status), kv(kv) { - n_kv = kv->get_size(); - head = 0; - } + llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { + n_kv = kv->get_size(); + head = 0; +} llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv, - llama_sbatch sbatch, - std::vector heads, - std::vector ubatches) - : status(status), - kv(kv), - sbatch(std::move(sbatch)), - heads(std::move(heads)), - ubatches(std::move(ubatches)) { + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { + if (!do_shift && dinfo.empty()) { + status = LLAMA_MEMORY_STATUS_NO_UPDATE; } +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + llama_kv_cache_unified::ubatch_heads heads, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) { +} llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; @@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() { bool llama_kv_cache_unified_state::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + // no ubatches -> this is a KV cache update + if (ubatches.empty()) { + kv->update(lctx, do_shift, dinfo); + + return true; + } + kv->apply_ubatch(heads[i_next], ubatches[i_next]); n_kv = kv->get_n_kv(); diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 1f1d44b97..6ff388a88 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -24,6 +24,19 @@ public: // this callback is used to filter out layers that should not be included in the cache using layer_filter_cb = std::function; + using ubatch_heads = std::vector; + + struct defrag_info { + bool empty() const { + return ids.empty(); + } + + // contains information about which cell moves where: + // - cell i moves to ids[i] + // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved + std::vector ids; + }; + llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -66,9 +79,7 @@ public: llama_memory_state_ptr init_full() override; - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; @@ -83,6 +94,8 @@ public: uint32_t get_size() const; + bool get_has_shift() const; + // // graph_build API // @@ -103,7 +116,9 @@ public: // find places for the provided ubatches in the cache, returns the head locations // return empty vector on failure - std::vector prepare(const std::vector & ubatches); + ubatch_heads prepare(const std::vector & ubatches); + + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); // return the cell position where we can insert the ubatch // return -1 on failure to find a contiguous slot of kv cells @@ -133,8 +148,7 @@ private: ggml_tensor * v; }; - bool do_defrag = false; - bool v_trans = true; // the value tensor is transposed + bool v_trans = true; // the value tensor is transposed // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method @@ -160,13 +174,8 @@ private: // model layer id -> KV cache layer id std::unordered_map map_layer_ids; - // defrag - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); + // return non-empty vector if cells have been moved + defrag_info defrag_prepare(int32_t n_max_nodes) const; size_t total_size() const; @@ -192,7 +201,8 @@ private: llm_graph_result_ptr build_graph_defrag( const llama_cparams & cparams, ggml_context * ctx, - ggml_cgraph * gf) const; + ggml_cgraph * gf, + const defrag_info & dinfo) const; void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -203,20 +213,29 @@ private: class llama_kv_cache_unified_state : public llama_memory_state_i { public: + // some shorthands + using ubatch_heads = llama_kv_cache_unified::ubatch_heads; + using defrag_info = llama_kv_cache_unified::defrag_info; + // used for errors llama_kv_cache_unified_state(llama_memory_status status); // used to create a full-cache state llama_kv_cache_unified_state( - llama_memory_status status, llama_kv_cache_unified * kv); - // used to create a state from a batch + // used to create an update state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo); + + // used to create a decode state from a batch llama_kv_cache_unified_state( - llama_memory_status status, llama_kv_cache_unified * kv, llama_sbatch sbatch, - std::vector heads, + ubatch_heads heads, std::vector ubatches); virtual ~llama_kv_cache_unified_state(); @@ -253,16 +272,30 @@ public: void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: - const llama_memory_status status; + llama_memory_status status; llama_kv_cache_unified * kv; + llama_context * lctx; + + // + // update state + // + + bool do_shift = false; + + defrag_info dinfo; + + // + // batch processing state + // llama_sbatch sbatch; // the index of the next ubatch to process size_t i_next = 0; - std::vector heads; + ubatch_heads heads; + std::vector ubatches; // diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 2d04705f2..17a5e5cb8 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -1,12 +1,16 @@ #pragma once #include "llama.h" -#include "llama-io.h" #include "llama-memory.h" +class llama_io_write_i; +class llama_io_read_i; + struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; + // TODO: move the init_ interfaces to llama_memory_i + // split the input batch into a set of ubatches and verify that they can fit into the cache // return a state object containing the ubatches and KV cache state required to process them // check the llama_memory_state_i::get_status() for the result @@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i { // simulate full cache, used for allocating worst-case compute buffers virtual llama_memory_state_ptr init_full() = 0; - // process any pending defrag/shift/etc. operations - // optionally call once before processing a new batch - // return true if any operations were performed - virtual bool update(llama_context & lctx) = 0; - - // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing - // TODO: change to - // llama_memory_state_ptr init_defrag(float thold) = 0; - // - virtual void defrag_sched(float thold) = 0; + // prepare for any pending memory updates, such as shifts, defrags, etc. + // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update + virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; // getters virtual bool get_can_shift() const = 0; diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp index 10173253e..f1107672c 100644 --- a/src/llama-memory.cpp +++ b/src/llama-memory.cpp @@ -1 +1,42 @@ #include "llama-memory.h" + +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) { + bool has_update = false; + + switch (s0) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s0; + } + } + + switch (s1) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s1; + } + } + + // if either status has an update, then the combined status has an update + return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; +} diff --git a/src/llama-memory.h b/src/llama-memory.h index b3799d66e..ab0d399c4 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -36,12 +36,19 @@ public: virtual bool get_can_edit() const = 0; }; +using llama_memory_ptr = std::unique_ptr; + enum llama_memory_status { LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_NO_UPDATE, LLAMA_MEMORY_STATUS_FAILED_PREPARE, LLAMA_MEMORY_STATUS_FAILED_COMPUTE, }; +// helper function for combining the status of two memory states +// useful for implementing hybrid memory types (e.g. iSWA) +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); + // the interface for managing the memory state during batch processing // this interface is implemented per memory type. see: // - llama_kv_cache_unified_state @@ -69,7 +76,7 @@ public: // get the current ubatch virtual const llama_ubatch & get_ubatch() const = 0; - // get the status of the memory state + // get the status of the memory state - used for error handling and checking if any updates would be applied virtual llama_memory_status get_status() const = 0; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 65015cabd..65fe79b68 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -961,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 46: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; case LLM_ARCH_GEMMA3: { @@ -981,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); @@ -8584,14 +8590,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - switch (model.type) { - case LLM_TYPE_2B: - case LLM_TYPE_9B: - case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break; - default: GGML_ABORT("fatal error"); - }; - cb(Qcur, "Qcur_scaled", il); + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, @@ -8732,9 +8731,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 508a64c58..40deab5ab 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -70,6 +70,7 @@ struct mtmd_cli_context { llama_model * model; llama_context * lctx; const llama_vocab * vocab; + common_sampler * smpl; llama_batch batch; int n_batch; @@ -89,8 +90,9 @@ struct mtmd_cli_context { model = llama_init.model.get(); lctx = llama_init.context.get(); vocab = llama_model_get_vocab(model); + smpl = common_sampler_init(model, params.sampling); n_threads = params.cpuparams.n_threads; - batch = llama_batch_init(params.n_batch, 0, 1); + batch = llama_batch_init(1, 0, 1); // batch for next token generation n_batch = params.n_batch; if (!model || !lctx) { @@ -118,6 +120,11 @@ struct mtmd_cli_context { } } + ~mtmd_cli_context() { + llama_batch_free(batch); + common_sampler_free(smpl); + } + void init_vision_context(common_params & params) { const char * clip_path = params.mmproj.path.c_str(); mtmd_context_params mparams = mtmd_context_params_default(); @@ -153,7 +160,7 @@ struct mtmd_cli_context { } }; -static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) { +static int generate_response(mtmd_cli_context & ctx, int n_predict) { llama_tokens generated_tokens; for (int i = 0; i < n_predict; i++) { if (i > n_predict || !g_is_generating || g_is_interrupted) { @@ -161,9 +168,9 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int break; } - llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1); + llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1); generated_tokens.push_back(token_id); - common_sampler_accept(smpl, token_id, true); + common_sampler_accept(ctx.smpl, token_id, true); if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) { LOG("\n"); @@ -261,7 +268,6 @@ int main(int argc, char ** argv) { bool is_single_turn = !params.prompt.empty() && !params.image.empty(); - struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling); int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict; // Ctrl+C handling @@ -300,7 +306,7 @@ int main(int argc, char ** argv) { if (eval_message(ctx, msg, true)) { return 1; } - if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) { + if (!g_is_interrupted && generate_response(ctx, n_predict)) { return 1; } @@ -366,7 +372,7 @@ int main(int argc, char ** argv) { return 1; } if (g_is_interrupted) break; - if (generate_response(ctx, smpl, n_predict)) { + if (generate_response(ctx, n_predict)) { return 1; } content.clear(); diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 64f03fd1e..686f42f39 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -311,6 +311,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, GGML_ABORT("chunk type not supported"); } + llama_batch_free(text_batch); return 0; } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4b92eeac9..9038df4c3 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -360,7 +360,7 @@ struct server_task { params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; } params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream; + params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (params_base.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); } @@ -2016,6 +2016,11 @@ struct server_context { params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by this context"); + return false; + } } return true; @@ -3203,9 +3208,7 @@ struct server_context { } } else { // if we don't cache the prompt, we have to remove the entire KV cache - llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; - slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()" } if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { @@ -3220,7 +3223,6 @@ struct server_context { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; } } diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index 610610749..20f048c6f 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -499,13 +499,12 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr @pytest.mark.slow -@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [ - (128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.STREAMED, None, "^I need to calculate [\\s\\S]*?To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), - (1024, 'deepseek', CompletionMode.STREAMED, None, "^First, I [\\s\\S]*?To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), +@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) +@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [ + (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), # (1024, 'none', CompletionMode.NORMAL, None, "^(\\s*)?I need[\\s\\S]*?\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None), ]) diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index f7e1b3b3b..bc547ca03 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -308,10 +308,12 @@ class ServerProcess: stream = data.get('stream', False) if stream: content: list[str] = [] + reasoning_content: list[str] = [] tool_calls: list[dict] = [] finish_reason: Optional[str] = None content_parts = 0 + reasoning_content_parts = 0 tool_call_parts = 0 arguments_parts = 0 @@ -322,6 +324,10 @@ class ServerProcess: assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!' content.append(choice['delta']['content']) content_parts += 1 + if choice['delta'].get('reasoning_content') is not None: + assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!' + reasoning_content.append(choice['delta']['reasoning_content']) + reasoning_content_parts += 1 if choice['delta'].get('finish_reason') is not None: finish_reason = choice['delta']['finish_reason'] for tc in choice['delta'].get('tool_calls', []): @@ -349,8 +355,10 @@ class ServerProcess: tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] if fct.get('arguments') is not None: tool_call['function']['arguments'] += fct['arguments'] + arguments_parts += 1 + tool_call_parts += 1 - print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') + print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') result = dict( choices=[ dict( @@ -359,6 +367,7 @@ class ServerProcess: message=dict( role='assistant', content=''.join(content) if content else None, + reasoning_content=''.join(reasoning_content) if reasoning_content else None, tool_calls=tool_calls if tool_calls else None, ), )