diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7b9893c8a..173a103ba 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -556,8 +556,11 @@ class TextModel(ModelBase): logger.info(f"gguf: experts used count = {n_experts_used}") if (head_dim := self.hparams.get("head_dim")) is not None: - self.gguf_writer.add_key_length(head_dim) - self.gguf_writer.add_value_length(head_dim) + # Workaround for incorrect AutoConfig value for DeepSeekV3 (is set correctly in DeepSeekV2Model class) + # https://github.com/huggingface/transformers/blob/19224c3642705c5b6988c9f5f4251f83323d05ae/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py#L210 + if self.hparams.get("model_type") != "deepseek_v3": + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) self.gguf_writer.add_file_type(self.ftype) logger.info(f"gguf: file type = {self.ftype}") @@ -4798,25 +4801,6 @@ class OlmoeModel(TextModel): class JinaBertV2Model(BertModel): model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.intermediate_size = self.hparams["intermediate_size"] - - def get_tensors(self): - for name, data in super().get_tensors(): - if 'gated_layer' in name: - d1 = data[:self.intermediate_size, :] - name1 = name.replace('gated_layers', 'gated_layers_w') - name1 = name1.replace('up_gated_layer', 'gated_layers_v') - d2 = data[self.intermediate_size:, :] - name2 = name.replace('gated_layers', 'gated_layers_v') - name2 = name2.replace('up_gated_layer', 'gated_layers_w') - yield name1, d1 - yield name2, d2 - continue - - yield name, data - def set_vocab(self): tokenizer_class = 'BertTokenizer' with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: @@ -4832,14 +4816,6 @@ class JinaBertV2Model(BertModel): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # if name starts with "bert.", remove the prefix - # e.g. 
https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - if name.startswith("bert."): - name = name[5:] - - return super().modify_tensors(data_torch, name, bid) - @ModelBase.register("OpenELMForCausalLM") class OpenELMModel(TextModel): diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 337d8094e..69415daa8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -518,11 +518,14 @@ void ggml_barrier(struct ggml_threadpool * tp); #elif defined(__GNUC__) // GCC/Clang on *nix # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT -#elif defined(_MSC_VER) && defined (_WIN64) +#elif defined(_MSC_VER) && defined(_WIN64) // MSVC // Note: C name mangling varies across different calling conventions // see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170 # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias)) +#elif defined(_MSC_VER) && defined(WIN32) +// ref: https://github.com/ggml-org/whisper.cpp/pull/3239#issuecomment-2958224591 +# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:_" #name "=_" #alias)) #else # error "Unsupported compiler for GGML_WEAK_ALIAS" #endif diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 58763e39e..5d7760217 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext( threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory @@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext( // O = diag(ms)*O { - s8x8_t mm; - simdgroup_load(mm, ss + 2*C, TS, 0, false); + s8x8_t ms; + simdgroup_load(ms, ss + 2*C, TS, 0, false); #pragma unroll(DV8) for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[i], ms, lo[i]); } } // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/8; ++cc) { - s8x8_t ms; - simdgroup_load(ms, ss + 8*cc, TS, 0, false); + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, TS, 0, false); if (is_same::value) { // we can read directly from global memory @@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 - simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); } } else { for (short ii = 0; ii < DV16; ii += 4) { @@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[2*(ii + 
k) + 1], vs, mv, lo[2*(ii + k) + 1]); } } else { if (ii + tx < DV16) { @@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); } } } @@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = 0; j < Q; ++j) { - if (tiisg == 0) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } + for (short j = tiisg; j < Q; j += NW) { + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; } } + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); + + // store result to shared memory in F32 + if (sgitg == 0) { + for (short i = 0; i < DV8; ++i) { + //simdgroup_store(lo[i], so + i*8, DV, 0, false); + simdgroup_float8x8 t(1.0f); + simdgroup_multiply(t, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < DV8; ++i) { - simdgroup_store(lo[i], so + i*8, DV, 0, false); - } - } + for (short j = tiisg; j < Q; j += NW) { + const float S0 = ss[j*TS - 1*SH + 0]; + const float S1 = ss[j*TS + 0]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TS + 0]; - const float S1 = ss[j*TS + sg*SH + 0]; - - const float M0 = ss[j*TS + 1]; - const float M1 = ss[j*TS + sg*SH + 1]; + const float M0 = ss[j*TS - 1*SH + 1]; + const float M1 = ss[j*TS + 1]; const float M = max(M0, M1); - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); + float ms0 = exp(M0 - M); + float ms1 = exp(M1 - M); const float S = S0*ms0 + S1*ms1; - if (tiisg == 0) { - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; - ss[j*TS + 2*C + j ] = ms0; - ss[j*TS + 2*C + j + sg*SH] = ms1; - } + ss[j*TS + 2*C + j - 1*SH] = ms0; + ss[j*TS + 2*C + j ] = ms1; } + //simdgroup_barrier(mem_flags::mem_threadgroup); + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { s8x8_t ms0; s8x8_t ms1; - simdgroup_load(ms0, ss + 2*C, TS, 0, false); - simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); + simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); + simdgroup_load(ms1, ss + 2*C, TS, 0, false); #pragma unroll(DV8) for (short i = 0; i < DV8; ++i) { - o8x8_t t; + simdgroup_float8x8 t; simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms1, t); + simdgroup_multiply(t, ms0, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(t, ms1, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { 
- simdgroup_store(lo[i], so + i*8, DV, 0, false); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK); + threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); // final rescale with 1/S and store to global memory for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { @@ -3723,8 +3718,8 @@ kernel void kernel_flash_attn_ext( half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ float, simdgroup_float8x8, \ - float, float4, simdgroup_float8x8 - //half, half4, simdgroup_half8x8 + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ @@ -3732,8 +3727,8 @@ kernel void kernel_flash_attn_ext( bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ float, simdgroup_float8x8, \ - float, float4, simdgroup_float8x8 - //half, half4, simdgroup_half8x8 + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl new file mode 100644 index 000000000..7ccf41efb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. 
+inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // subgroup size +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = 0; + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
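+    // (In the flat layout the scales src0_d and the packed nibbles src0_q are
+    // two separate buffers, so the same block index is used for both: as-is
+    // for the half-precision scales, and times QK4_0/2 bytes per block for
+    // the quantized data.)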
+ ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_id_q4_0_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global char * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb02, + int ne10, + int ne11, + int ne12, + ulong nb11, + ulong nb12, + int ne20, + int ne21, + ulong nb21, + int ne0, + int ne1, + int r2, + int r3 +) { + 
src1 = (global float *)((global char *)src1 + offset1); + src2 = (global char *)((global char *)src2 + offset2); + dst = (global float *)((global char *)dst + offsetd); + + const int iid1 = get_group_id(2)/ne20; + const int idx = get_group_id(2)%ne20; + + const int i02 = ((global int *)(src2 + iid1*nb21))[idx]; + + const int i11 = idx%ne11; + const int i12 = iid1; + + const int i1 = idx; + const int i2 = i12; + + global char * src0_q_cur = src0_q + (i02*nb02/nb00)*(QK4_0/2); + global half * src0_d_cur = src0_d + (i02*nb02/nb00); + global float * src1_cur = (global float *)((global char *) src1 + i11*nb11 + i12*nb12); + global float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + mul_vec_q_n_f32_8x_flat(src0_q_cur, src0_d_cur, src1_cur, dst_cur, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4a8a0dba6..70068398f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -94,7 +94,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de -#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 #define GGML_VK_MAX_NODES 8192 @@ -118,25 +118,11 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -struct vk_queue { - uint32_t queue_family_index; - vk::Queue queue; - vk::CommandPool pool; - uint32_t cmd_buffer_idx; - std::vector cmd_buffers; - - vk::PipelineStageFlags stage_flags; - - bool transfer_only; -}; +#define MAX_PARAMETER_COUNT 8 struct vk_pipeline_struct { std::string name; vk::ShaderModule shader_module; - vk::DescriptorSetLayout dsl; - std::vector descriptor_pools; - std::vector descriptor_sets; - uint32_t descriptor_set_idx; vk::PipelineLayout layout; vk::Pipeline pipeline; uint32_t push_constant_size; @@ -183,6 +169,40 @@ struct ggml_backend_vk_buffer_type_context { vk_device device; }; +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. 
+struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); @@ -357,6 +377,8 @@ struct vk_device_struct { // set to true to indicate that some shaders need to be compiled after the dryrun bool need_compiles {}; + vk::DescriptorSetLayout dsl; + vk_matmul_pipeline pipeline_matmul_f32 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {}; vk_matmul_pipeline pipeline_matmul_bf16 {}; @@ -474,7 +496,6 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; - std::unordered_map pipeline_descriptor_set_requirements; std::vector> pinned_memory; @@ -499,10 +520,8 @@ struct vk_device_struct { ggml_vk_destroy_buffer(sync_staging); - device.destroyCommandPool(compute_queue.pool); - if (!single_queue) { - device.destroyCommandPool(transfer_queue.pool); - } + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); for (auto& pipeline : pipelines) { if (pipeline.second.expired()) { @@ -514,10 +533,26 @@ struct vk_device_struct { } pipelines.clear(); + device.destroyDescriptorSetLayout(dsl); + device.destroy(); } }; +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + struct vk_buffer_struct { vk::Buffer buffer = VK_NULL_HANDLE; vk::DeviceMemory device_memory = VK_NULL_HANDLE; @@ -835,7 +870,7 @@ struct vk_context_struct { std::vector in_memcpys; std::vector out_memcpys; - vk_queue * q; + vk_command_pool * p {}; }; typedef std::shared_ptr vk_context; typedef std::weak_ptr vk_context_ref; @@ -946,6 +981,14 @@ struct ggml_backend_vk_context { vk_context_ref transfer_ctx; std::vector tensor_ctxs; + + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx {}; + uint32_t pipeline_descriptor_set_requirements {}; + + vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -1076,39 +1119,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] 
> 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); - std::vector dsl_binding; - std::vector dsl_binding_flags; - for (uint32_t i = 0; i < parameter_count; i++) { - dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); - dsl_binding_flags.push_back({}); - } - - vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; - vk::PushConstantRange pcr( vk::ShaderStageFlagBits::eCompute, 0, pipeline->push_constant_size ); - vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( - {}, - dsl_binding); - descriptor_set_layout_create_info.setPNext(&dslbfci); - pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); - - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - - pipeline->descriptor_set_idx = 0; - - vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); std::vector specialization_entries(specialization_constants.size()); @@ -1183,15 +1206,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); - for (auto& pool : pipeline->descriptor_pools) { - device.destroyDescriptorPool(pool); - } - pipeline->descriptor_pools.clear(); - pipeline->descriptor_sets.clear(); - pipeline->descriptor_set_idx = 0; - - device.destroyDescriptorSetLayout(pipeline->dsl); - device.destroyPipelineLayout(pipeline->layout); device.destroyShaderModule(pipeline->shader_module); @@ -1199,97 +1213,76 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) device.destroyPipeline(pipeline->pipeline); } -static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); - device->pipeline_descriptor_set_requirements[pipeline->name] += n; + ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { pipeline->needed = true; - device->need_compiles = true; + ctx->device->need_compiles = true; } } -static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { - std::lock_guard guard(device->mutex); +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { - for (auto& pair : device->pipeline_descriptor_set_requirements) { - vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); - const uint64_t n = pair.second; + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } - 
VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); + vk_device& device = ctx->device; - if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { - // Enough descriptors are available - continue; + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); } - uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); - uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; - uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - while (to_alloc > 0) { - const uint32_t alloc_count = std::min(pool_remaining, to_alloc); - to_alloc -= alloc_count; - pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - if (pool_idx >= pipeline->descriptor_pools.size()) { - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - } - - std::vector layouts(alloc_count); - for (uint32_t i = 0; i < alloc_count; i++) { - layouts[i] = pipeline->dsl; - } - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); - std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); - - pool_idx++; + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = device->dsl; } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; } } -static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { - VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); - pipeline->descriptor_set_idx = 0; -} - -static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); - std::lock_guard guard(device->mutex); - if (q.cmd_buffers.size() > q.cmd_buffer_idx) { + if (p.cmd_buffers.size() > p.cmd_buffer_idx) { // Reuse command buffer - return q.cmd_buffers[q.cmd_buffer_idx++]; + return 
p.cmd_buffers[p.cmd_buffer_idx++]; } vk::CommandBufferAllocateInfo command_buffer_alloc_info( - q.pool, + p.pool, vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); auto buf = cmd_buffers.front(); - q.cmd_buffers.push_back(buf); - q.cmd_buffer_idx++; + p.cmd_buffers.push_back(buf); + p.cmd_buffer_idx++; return buf; } -static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { - VK_LOG_DEBUG("ggml_vk_create_submission()"); - vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); - return s; -} - static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { if (ctx->seqs.empty()) { if (fence) { - ctx->q->queue.submit({}, fence); + ctx->p->q->queue.submit({}, fence); } return; } @@ -1328,7 +1321,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { tl_signal_vals.push_back({}); tl_signal_semaphores.push_back({}); for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { - stage_flags[idx].push_back(ctx->q->stage_flags); + stage_flags[idx].push_back(ctx->p->q->stage_flags); tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); } @@ -1358,7 +1351,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { } } - ctx->q->queue.submit(submit_infos, fence); + ctx->p->q->queue.submit(submit_infos, fence); ctx->seqs.clear(); } @@ -1416,28 +1409,25 @@ static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_ q.queue_family_index = queue_family_index; q.transfer_only = transfer_only; - vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); - q.pool = device->device.createCommandPool(command_pool_create_info_compute); - - q.cmd_buffer_idx = 0; + q.cmd_pool.init(device, &q); q.queue = device->device.getQueue(queue_family_index, queue_index); q.stage_flags = stage_flags; } -static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); ctx->gc.contexts.emplace_back(result); - result->q = &q; + result->p = &p; return result; } -static vk_context ggml_vk_create_temporary_context(vk_queue& q) { +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); - result->q = &q; + result->p = &p; return result; } @@ -1470,15 +1460,29 @@ static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { return ctx->gc.events[ctx->event_idx++]; } -static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { - VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); - std::lock_guard guard(device->mutex); +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); // Requires command buffers to be done - device->device.resetCommandPool(q.pool); - q.cmd_buffer_idx = 0; + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; } +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + 
VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { vk::MemoryType memory_type = mem_props->memoryTypes[i]; @@ -1497,8 +1501,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor printf("\nWARNING: Requested buffer size (%zu) exceeds device memory allocation limit (%zu)!\n",size,device->max_memory_allocation_size); } - std::lock_guard guard(device->mutex); - vk_buffer buf = std::make_shared(); if (size == 0) { @@ -1627,11 +1629,11 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { static void ggml_vk_sync_buffers(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->q->transfer_only; + const bool transfer_queue = ctx->p->q->transfer_only; ctx->s->buffer.pipelineBarrier( - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -1650,8 +1652,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ctx->s->buffer.waitEvents( events, - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, {}, {} @@ -3393,6 +3395,22 @@ static vk_device ggml_vk_get_device(size_t idx) { } } + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + ggml_vk_load_shaders(device); if (!device->single_queue) { @@ -3400,7 +3418,8 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); } else { // TODO: Use pointer or reference to avoid copy - device->transfer_queue = device->compute_queue; + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); } device->buffer_type = { @@ -3619,11 +3638,11 @@ static void ggml_vk_instance_init() { vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); - // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) 
{ + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + std::string devices(devices_env); std::replace(devices.begin(), devices.end(), ',', ' '); @@ -3639,9 +3658,9 @@ static void ggml_vk_instance_init() { } else { std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - // Make sure at least one device exists + // If no vulkan devices are found, return early if (devices.empty()) { - std::cerr << "ggml_vulkan: Error: No devices found." << std::endl; + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); return; } @@ -3724,9 +3743,20 @@ static void ggml_vk_instance_init() { } } - // If no dedicated GPUs found, fall back to GPU 0 + // If no dedicated GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. if (vk_instance.device_indices.empty()) { - vk_instance.device_indices.push_back(0); + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; } } GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); @@ -3755,6 +3785,9 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 
0 : atoi(skip_checks)); @@ -4120,9 +4153,9 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf } } -static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); + s.buffer = ggml_vk_create_cmd_buffer(device, p); if (one_time) { s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); } else { @@ -4167,10 +4200,10 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); - GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); - GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); - vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); @@ -4207,7 +4240,7 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { ggml_vk_ctx_end(subctx); } - subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } @@ -4408,7 +4441,9 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); } } else { - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); ggml_vk_ctx_end(subctx); @@ -4420,6 +4455,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } } @@ -4496,7 +4532,9 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ memcpy(dst, (uint8_t *) src->ptr + offset, size); } else { - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); ggml_vk_ctx_end(subctx); @@ -4504,6 +4542,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ 
src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -4523,15 +4562,17 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); // Copy within the device - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); } else { VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device @@ -4556,7 +4597,8 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); @@ -4564,6 +4606,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { @@ -4977,18 +5020,18 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } if (quantize_y) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); } return; } @@ -5170,12 +5213,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // Request descriptor sets if 
(qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5308,7 +5351,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -5397,7 +5440,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); return; } @@ -5584,12 +5627,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } return; } @@ -5778,12 +5821,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte // Request descriptor sets if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -6103,9 +6146,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } return; } @@ -6668,7 +6711,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -7049,7 +7092,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(pipeline != nullptr); if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -7188,7 +7231,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont GGML_ASSERT(pipeline != nullptr); if (dryrun) { - 
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -7866,9 +7909,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -7883,7 +7926,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -7925,7 +7968,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { ggml_vk_matmul( @@ -7941,6 +7984,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); double time = std::chrono::duration_cast(end-begin).count() / 1000.0; @@ -8042,16 +8086,13 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); ggml_vk_destroy_buffer(d_D); - ggml_pipeline_cleanup(p); - ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); - free(x); free(y); free(d); @@ -8129,17 +8170,17 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_quantize_data(x, qx, ne, quant); ggml_vk_dequantize_data(qx, x_ref, ne, quant); - ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_request_descriptor_sets(ctx, p, 1); if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; 
ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); @@ -8150,6 +8191,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -8229,17 +8271,17 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); // -// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); // // if (ctx->device->need_compiles) { // ggml_vk_load_shaders(ctx->device); // } // -// ggml_pipeline_allocate_descriptor_sets(ctx->device); +// ggml_pipeline_allocate_descriptor_sets(ctx); // // ggml_vk_buffer_write(x_buf, 0, x, x_sz); // -// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); // ggml_vk_ctx_begin(ctx->device, subctx); // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); // ggml_vk_ctx_end(subctx); @@ -8249,6 +8291,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // ggml_vk_submit(subctx, ctx->fence); // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); // ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); // // auto end = std::chrono::high_resolution_clock::now(); // @@ -8388,9 +8431,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -8401,19 +8444,19 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); } if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); if (mmq) { for (size_t i = 0; i < num_it; i++) { @@ -8442,6 +8485,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + 
ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -8756,7 +8800,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod if (!dryrun) { if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { @@ -8810,7 +8854,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod // These operations all go through ggml_vk_op_f32, so short-circuit and // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return false; } default: @@ -9202,19 +9246,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { } ctx->gc.temp_buffers.clear(); - for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { - vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; - - if (plr.expired()) { - continue; - } - - vk_pipeline pl = plr.lock(); - ggml_pipeline_cleanup(pl); - } - - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -9235,7 +9268,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->tensor_ctxs.clear(); ctx->gc.contexts.clear(); - ctx->device->pipeline_descriptor_set_requirements.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; } // Clean up on backend free @@ -9262,6 +9296,15 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->device->device.destroyFence(ctx->fence); ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); } static int ggml_vk_get_device_count() { @@ -9528,7 +9571,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9551,7 +9594,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9574,7 +9617,7 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, 
ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -9635,7 +9678,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_load_shaders(ctx->device); } ggml_vk_preallocate_buffers(ctx); - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); int last_node = cgraph->n_nodes - 1; @@ -9667,7 +9710,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); GGML_ASSERT(ctx->compute_ctx.expired()); - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); @@ -9702,7 +9745,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (vk_perf_logger_enabled) { if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 93dd1d802..439fc1afe 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -333,7 +333,9 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc11", # nomic-bert "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe "model.layers.{bid}.mlp.c_fc", # starcoder2 - "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used) + "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU) + "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU) "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone @@ -370,7 +372,7 @@ class TensorNameMap: "model.layers.layers.{bid}.mlp.gate_proj", # plamo "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert - "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 + "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b7148e054..08eea301d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - const int64_t n_kv = kv_state->get_n_kv(); - - if (s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); - float * data = (float *) s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - data[i] = kv_state->s_mask(i); - } - } -} - void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -650,6 +634,7 @@ ggml_tensor * llm_graph_context::build_ffn( { // Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); @@ -663,7 +648,7 @@ ggml_tensor * llm_graph_context::build_ffn( { // Split into two equal parts int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); @@ -986,23 +971,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const auto * kv_state = static_cast(mstate); - - auto inp = std::make_unique(kv_state); - - const auto n_kv = kv_state->get_n_kv(); - - auto & cur = inp->s_mask; - - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1455,43 +1423,53 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_copy_mask_state( +ggml_tensor * llm_graph_context::build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, - int32_t n_state, - int32_t n_seqs) const { + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { const auto * kv_state = static_cast(mstate); const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); + const auto rs_zero = kv_state->get_rs_z(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size()); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx0, states, state_copy); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? 
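
For reference, a minimal CPU-side sketch of the index plumbing that the reworked build_recurrent_state in llama-graph.cpp sets up. The real code emits ggml graph ops (ggml_scale_inplace on a view of the zero-ed cell, ggml_get_rows over state_copy); rs_z and state_copy are taken from this hunk, while gather_states and the sample values are hypothetical and only illustrate the intended data movement.

// Simplified CPU-side model of the gather performed by build_recurrent_state.
// The real code builds ggml ops; this sketch only mirrors the index plumbing.
#include <cassert>
#include <cstdio>
#include <vector>

using state = std::vector<float>;

// states: one row per KV cell; state_copy[i]: source cell for slot i;
// rs_z: the single cell used as the shared zero-ed state (-1 if none).
static std::vector<state> gather_states(std::vector<state> states,
                                        const std::vector<int> & state_copy,
                                        int rs_z, int n_seqs) {
    // clear a single state; cells that had no valid source point at it
    if (rs_z >= 0) {
        for (float & v : states[rs_z]) {
            v = 0.0f;
        }
    }

    // the first n_seqs entries of state_copy select the rows that the
    // recurrent layer will actually read and update
    std::vector<state> out;
    for (int i = 0; i < n_seqs; ++i) {
        assert(state_copy[i] >= 0 && (size_t) state_copy[i] < states.size());
        out.push_back(states[state_copy[i]]);
    }
    return out;
}

int main() {
    // 4 cells of width 2; cell 3 is unreferenced and becomes the zero-ed state
    std::vector<state> states = {{1, 1}, {2, 2}, {3, 3}, {9, 9}};
    const std::vector<int> state_copy = {0, 3, 2, 1}; // slot 1 starts from scratch
    const auto out = gather_states(states, state_copy, /*rs_z=*/3, /*n_seqs=*/3);

    for (const auto & s : out) {
        printf("%g %g\n", s[0], s[1]); // 1 1 / 0 0 / 3 3
    }
    return 0;
}
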
- states = ggml_mul(ctx0, states, state_mask); + ggml_tensor * output_states; - // copy states which won't be changed further (between n_seqs and n_kv) + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); + } else { + // FIXME: make the gathering operation happen before the copy below + // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) + output_states = states; + } + + // copy extra states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)), - ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0); + return output_states; } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -1502,8 +1480,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = kv_state->get_k_l(il); - ggml_tensor * token_shift = build_copy_mask_state( - gf, token_shift_all, state_copy, state_mask, + ggml_tensor * token_shift = build_recurrent_state( + gf, token_shift_all, state_copy, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 28da6a522..88fb77f1d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -200,18 +200,6 @@ public: const llama_kv_cache_recurrent_state * kv_state; }; -class llm_graph_input_s_mask : public llm_graph_input_i { -public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} - virtual ~llm_graph_input_s_mask() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * s_mask; // F32 [1, n_kv] - - const llama_kv_cache_recurrent_state * kv_state; -}; - class llm_graph_input_cross_embd : public llm_graph_input_i { public: llm_graph_input_cross_embd( @@ -521,7 +509,6 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; ggml_tensor * build_inp_s_copy() const; - ggml_tensor * build_inp_s_mask() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -606,18 +593,17 @@ struct llm_graph_context { // recurrent // - ggml_tensor * build_copy_mask_state( + ggml_tensor * build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, - int32_t n_state, - int32_t n_seqs) const; + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const 
llama_ubatch & ubatch, int il) const; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 87fed8ded..3ed46bb45 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector & ubatche bool success = true; - // TODO: here we have to verify that all ubatches can fit in the cells - // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells - // during the compute of each ubatch. to reproduce, uncomment the following loop and run: - // - // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8 - // - // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed - // - GGML_UNUSED(ubatches); - //for (const auto & ubatch : ubatches) { - // if (!find_slot(ubatch)) { - // success = false; - // break; - // } - //} + for (const auto & ubatch : ubatches) { + if (!find_slot(ubatch)) { + success = false; + break; + } + } // restore the original state cells = std::move(org_cells); @@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector & ubatche } bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seq_tokens = ubatch.n_seq_tokens; // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*n_tokens) { + if (head > used + 2*n_seqs) { head = 0; } @@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { empty_cell.src = orig_cell.src; orig_cell.seq_id.erase(seq_id); empty_cell.seq_id.insert(seq_id); // will be overwritten + GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id } seq_meta.tail = next_empty_cell; // find next empty cell if (s + 1 < n_seqs) { - next_empty_cell += 1; for (uint32_t i = 0; i < size; ++i) { + next_empty_cell += 1; if (next_empty_cell >= size) { next_empty_cell -= size; } kv_cell & cell = cells[next_empty_cell]; if (cell.is_empty()) { break; } - next_empty_cell += 1; } } } @@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { // gather and re-order for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + const int32_t dst_id = s + min; + const int32_t src_id = cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { kv_cell & dst_cell = cells[dst_id]; kv_cell & src_cell = cells[src_id]; @@ -563,12 +553,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { std::swap(dst_cell.src, src_cell.src); std::swap(dst_cell.seq_id, src_cell.seq_id); - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; + // swap tails + for (uint32_t i = 0; i < size; ++i) { + int32_t & tail = cells[i].tail; + if (tail == src_id) { + tail = dst_id; + } else if (tail == dst_id) { + tail = src_id; + } } } } @@ -576,7 +568,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { // update the pos of the used seqs for (uint32_t s = 0; s < n_seqs; ++s) { const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + 
n_seq_tokens - 1]; - int32_t cell_id = s + min; + const int32_t cell_id = s + min; kv_cell & cell = cells[cell_id]; if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { @@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { } } + // Find first cell without src refs, to use as the zero-ed state + { + // TODO: bake-in src refcounts in the cell metadata + std::vector refcounts(size, 0); + for (size_t i = 0; i < size; ++i) { + const int32_t src = cells[i].src; + if (src >= 0) { + refcounts[src] += 1; + } + } + + rs_z = -1; + for (int i = min; i <= max; ++i) { + if (refcounts[i] == 0) { + rs_z = i; + break; + } + } + + for (int i = min; i <= max; ++i) { + if (cells[i].src < 0) { + GGML_ASSERT(rs_z >= 0); + cells[i].src0 = rs_z; + } else { + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + } + cells[i].src = i; // avoid moving or clearing twice + } + } + // allow getting the range of used cells, from head to head + n head = min; n = max - min + 1; @@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { } bool llama_kv_cache_recurrent::get_can_shift() const { - return false; -} - -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; -} - -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - float res = (float) (cell.src >= 0); - - // only clear once - if (cell.src < 0) { - cell.src = cell_id; - } - - return res; + // shifting the pos is trivial for recurrent models + return true; } size_t llama_kv_cache_recurrent::total_size() const { @@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const { return is_full ? 0 : kv->head; } +int32_t llama_kv_cache_recurrent_state::get_rs_z() const { + return is_full ? 
0 : kv->rs_z; +} + uint32_t llama_kv_cache_recurrent_state::get_size() const { return kv->size; } @@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { } int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { - return kv->s_copy(i); -} - -float llama_kv_cache_recurrent_state::s_mask(int i) const { - return kv->s_mask(i); + return kv->cells[i + kv->head].src0; } diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index d1da12256..4b33bafd7 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -57,10 +57,6 @@ public: bool get_can_shift() const override; - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; - // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -73,10 +69,14 @@ public: // computed before each graph build uint32_t n = 0; + // first zero-ed state + int32_t rs_z = -1; + // TODO: optimize for recurrent state needs struct kv_cell { llama_pos pos = -1; - int32_t src = -1; // used to copy states + int32_t src = -1; // used to know where states should be copied from + int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once) int32_t tail = -1; std::set seq_id; @@ -157,13 +157,13 @@ public: uint32_t get_n_kv() const; uint32_t get_head() const; + int32_t get_rs_z() const; uint32_t get_size() const; ggml_tensor * get_k_l(int32_t il) const; ggml_tensor * get_v_l(int32_t il) const; int32_t s_copy(int i) const; - float s_mask(int i) const; private: const llama_memory_status status; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 4f1fb3ab5..7655688b3 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -127,6 +127,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); + debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } void llama_kv_cache_unified::clear(bool data) { @@ -462,7 +465,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d for (uint32_t i = 0; i < n_kv; ++i) { assert(dinfo.ids[i] <= n_kv); - if (dinfo.ids[i] == n_kv) { + if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { continue; } @@ -512,21 +515,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { head_cur = 0; } - // otherwise, one cell per token. 
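
For reference, a standalone sketch of the rs_z/src0 bookkeeping introduced above in llama_kv_cache_recurrent::find_slot: count how often each cell is used as a copy source, pick the first unreferenced in-range cell as the shared zero-ed state, and freeze the per-cell sources in src0 so the inputs can be set without mutating the cache. The cell struct is stripped down to the two source fields and assign_sources is a hypothetical helper, not the actual implementation.

#include <cstdio>
#include <vector>

struct cell {
    int src  = -1; // where the state should be copied from
    int src0 = -1; // frozen copy of src, used when setting the inputs
};

// returns rs_z: the first in-range cell that no other cell references,
// which is then used as the shared zero-ed source for fresh sequences
static int assign_sources(std::vector<cell> & cells, int min, int max) {
    std::vector<int> refcounts(cells.size(), 0);
    for (const cell & c : cells) {
        if (c.src >= 0) {
            refcounts[c.src] += 1;
        }
    }

    int rs_z = -1;
    for (int i = min; i <= max; ++i) {
        if (refcounts[i] == 0) {
            rs_z = i;
            break;
        }
    }

    for (int i = min; i <= max; ++i) {
        cells[i].src0 = cells[i].src < 0 ? rs_z : cells[i].src;
        cells[i].src  = i; // avoid copying or clearing the same state twice
    }
    return rs_z;
}

int main() {
    // cells 0..2 are in use; cells 0 and 1 both continue from cell 0,
    // cell 2 starts a fresh sequence (no valid source)
    std::vector<cell> cells(4);
    cells[0].src = 0;
    cells[1].src = 0;
    cells[2].src = -1;

    const int rs_z = assign_sources(cells, 0, 2);
    printf("rs_z = %d\n", rs_z); // 1: no other cell copies from it
    for (int i = 0; i <= 2; ++i) {
        printf("cell %d: src0 = %d\n", i, cells[i].src0); // 0, 0, 1
    }
    return 0;
}
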
- if (n_tokens > cells.size()) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); return -1; } -//#define FIND_SLOT_DEBUG 1 -#if FIND_SLOT_DEBUG - LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa); + if (debug > 0) { + LLAMA_LOG_CONT("\n"); + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); - // for debugging - { - std::string ss; - if (n_swa > 0) { + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; for (uint32_t i = 0; i < cells.size(); ++i) { if (cells.is_empty(i)) { ss += '.'; @@ -534,21 +533,45 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { ss += std::to_string(cells.seq_get(i)); } if (i%256 == 255) { + ss += " *"; ss += '\n'; } } - } - LLAMA_LOG_WARN("\n%s\n", ss.c_str()); - } - - for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { - if (cells.seq_pos_min(s) < 0) { - continue; + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); } - LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + std::string cur; + if (cells.is_empty(i)) { + cur = '.'; + } else { + cur = std::to_string(cells.pos_get(i)); + } + const int n = cur.size(); + for (int j = 0; j < 5 - n; ++j) { + cur += ' '; + } + ss += cur; + if (i%256 == 255) { + ss += " *"; + } + if (i%64 == 63) { + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } } -#endif uint32_t n_tested = 0; @@ -559,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { continue; } - // keep track of what the minimum sequence positions would be if we accept the ubatch - llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; - for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { - seq_pos_min[s] = cells.seq_pos_min(s); - } - bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { - const llama_pos pos = ubatch.pos[i]; - const llama_seq_id seq_id = ubatch.seq_id[i][0]; + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; // can we use this cell? 
either: // - the cell is empty // - the cell is occupied only by one sequence: - // - mask causally, if the sequence is the same as the one we are inserting + // - (disabled) mask causally, if the sequence is the same as the one we are inserting // - mask SWA, using current max pos for that sequence in the cache // always insert in the cell with minimum pos bool can_use = cells.is_empty(head_cur + i); @@ -581,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { if (!can_use && cells.seq_count(head_cur + i) == 1) { const llama_pos pos_cell = cells.pos_get(head_cur + i); - // causal mask - if (cells.seq_has(head_cur + i, seq_id)) { - can_use = pos_cell >= pos; - } + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(head_cur + i, seq_id)) { + // can_use = pos_cell >= pos; + //} if (!can_use) { const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); // SWA mask - // note: we insert only in the cell with minimum pos in order to preserve the invariant that - // all positions between [pos_min, pos_max] for each sequence will be present in the cache - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - if (pos_cell == seq_pos_min[seq_id_cell] && - is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - seq_pos_min[seq_id_cell]++; + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -623,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { } void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos_max_rm[s] = -1; + } + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { if (!cells.is_empty(head_cur + i)) { + assert(cells.seq_count(head_cur + i) == 1); + + const llama_seq_id seq_id = cells.seq_get(head_cur + i); + const llama_pos pos = cells.pos_get(head_cur + i); + + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + cells.rm(head_cur + i); } @@ -635,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch } } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } + + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { + LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", + __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } + } + // move the head at the end of the slot head = head_cur + ubatch.n_tokens; } @@ -944,11 +987,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - //GGML_ASSERT(kv_self->size == n_ctx); - auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); ggml_set_input(inp->k_shift); for (const auto & layer : layers) { diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 49f410ef6..cf4c691ba 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -158,6 +158,8 @@ private: // SWA const uint32_t n_swa = 0; + int debug = 0; + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 9e2c4d927..acf30aebe 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -80,6 +80,9 @@ public: assert(isrc < pos.size()); assert(idst < pos.size()); + assert(pos[idst] == -1); + assert(pos[isrc] != -1); + pos [idst] = pos [isrc]; shift[idst] = shift[isrc]; seq [idst] = seq [isrc]; @@ -144,9 +147,10 @@ public: assert(pos[i] != -1); seq_pos_rm(i); + seq[i].reset(); pos[i] = -1; - seq[i].reset(); + shift[i] = 0; used.erase(i); } @@ -164,6 +168,7 @@ public: if (seq[i].none()) { pos[i] = -1; + shift[i] = 0; used.erase(i); @@ -192,6 +197,7 @@ public: seq[i].reset(); pos[i] = -1; + shift[i] = 0; used.erase(i); @@ -317,21 +323,20 @@ public: pos[i] += d; shift[i] += d; - seq_pos_add(i); - has_shift = true; if (pos[i] < 0) { - seq_pos_rm(i); - seq[i].reset(); pos[i] = -1; + shift[i] = 0; used.erase(i); return true; } + seq_pos_add(i); + return false; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 730ecd7bd..2bfdd81e7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2320,8 +2320,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? 
n_ff : n_ff * 2}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -6143,7 +6143,7 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - LLM_FFN_GELU, LLM_FFN_PAR, il); + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { cur = build_ffn(cur, @@ -8957,7 +8957,6 @@ struct llm_build_mamba : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -8966,8 +8965,7 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il); - cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il); + cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -9008,7 +9006,6 @@ struct llm_build_mamba : public llm_graph_context { ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -9035,12 +9032,12 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states - ggml_tensor * conv = build_copy_mask_state( - gf, conv_states_all, state_copy, state_mask, + ggml_tensor * conv = build_recurrent_state( + gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_copy_mask_state( - gf, ssm_states_all, state_copy, state_mask, + ggml_tensor * ssm = build_recurrent_state( + gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); @@ -11756,7 +11753,6 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -11880,8 +11876,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_state->get_v_l(il), state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11937,7 +11933,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11948,7 +11943,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -11964,7 +11959,7 @@ struct 
llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12035,7 +12030,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12046,7 +12040,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -12059,7 +12053,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12151,7 +12145,6 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { @@ -12234,8 +12227,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_state->get_v_l(il), state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12293,7 +12286,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12304,7 +12296,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -12320,7 +12312,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12387,7 +12379,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = 
build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12398,7 +12389,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -12411,7 +12402,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh new file mode 100755 index 000000000..86e839133 --- /dev/null +++ b/tests/test-tokenizers-repo.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +if [ $# -lt 2 ]; then + printf "Usage: $0 []\n" + exit 1 +fi + +if [ $# -eq 3 ]; then + toktest=$3 +else + toktest="./test-tokenizer-0" +fi + +if [ ! -x $toktest ]; then + printf "Test executable \"$toktest\" not found!\n" + exit 1 +fi + +repo=$1 +folder=$2 + +if [ -d $folder ] && [ -d $folder/.git ]; then + (cd $folder; git pull) +else + git clone $repo $folder +fi + +shopt -s globstar +for gguf in $folder/**/*.gguf; do + if [ -f $gguf.inp ] && [ -f $gguf.out ]; then + $toktest $gguf + else + printf "Found \"$gguf\" without matching inp/out files, ignoring...\n" + fi +done + diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 35b9e702f..0fb01665a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 77dcbc11b..1b1cf439b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -233,6 +233,7 @@ struct server_task { slot_params defaults; defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -2060,6 +2061,7 @@ struct server_context { SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); slot.params.sampling = params_base.sampling; + slot.params.n_keep = params_base.n_keep; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -3556,9 +3558,6 @@ struct server_context { const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - // keep track of total number of tokens generated in the draft - slot.n_draft_total += draft.size(); - // ignore small drafts if (slot.params.speculative.n_min > (int) draft.size()) { SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); @@ -3566,6 +3565,9 @@ struct server_context { continue; } + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + // construct the speculation batch common_batch_clear(slot.batch_spec); common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); @@ -3584,7 
+3586,7 @@ struct server_context { slot.n_past += ids.size(); slot.n_decoded += ids.size(); - // update how many tokens out of draft was accepted + // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; slot.cache_tokens.push_back(id); diff --git a/tools/server/webui/src/index.scss b/tools/server/webui/src/index.scss index 64460b740..362db6e17 100644 --- a/tools/server/webui/src/index.scss +++ b/tools/server/webui/src/index.scss @@ -41,6 +41,10 @@ html { max-width: 900px; } +.chat-bubble { + @apply break-words; +} + .chat-bubble-base-300 { --tw-bg-opacity: 1; --tw-text-opacity: 1;
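
For reference, a schematic model of the corrected speculative-decoding bookkeeping in tools/server/server.cpp above: drafts smaller than n_min are skipped before they are added to n_draft_total, so the reported acceptance rate only covers drafts that were actually tested. draft_stats, on_draft and the numbers below are hypothetical illustrations, not server code.

#include <cstdio>
#include <vector>

struct draft_stats {
    int n_draft_total    = 0; // drafted tokens that were actually tested
    int n_draft_accepted = 0; // tokens accepted out of those tested
};

static void on_draft(draft_stats & st, int draft_size, int accepted, int n_min) {
    if (draft_size < n_min) {
        return; // ignored small draft: no longer counted towards the total
    }
    st.n_draft_total    += draft_size;
    st.n_draft_accepted += accepted;
}

int main() {
    draft_stats st;
    const int n_min = 4;

    on_draft(st, 2, 0, n_min); // too small, skipped entirely
    on_draft(st, 8, 5, n_min);
    on_draft(st, 6, 6, n_min);

    printf("accepted %d/%d (%.1f%%)\n", st.n_draft_accepted, st.n_draft_total,
           100.0 * st.n_draft_accepted / st.n_draft_total);
    return 0;
}
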