diff --git a/ggml/src/ggml-cpu/kcpp-quantmapper.c b/ggml/src/ggml-cpu/kcpp-quantmapper.c index fd5a27a85..59ed593f3 100644 --- a/ggml/src/ggml-cpu/kcpp-quantmapper.c +++ b/ggml/src/ggml-cpu/kcpp-quantmapper.c @@ -21,4 +21,5 @@ #include "arch/s390/quants.c" #else #pragma message("KoboldCpp Cannot Compile Quants! Unknown Architecture!") +#error "Compilation halted due to unknown architecture." #endif \ No newline at end of file diff --git a/ggml/src/ggml-cpu/kcpp-repackmapper.cpp b/ggml/src/ggml-cpu/kcpp-repackmapper.cpp index 75d84bddc..ad46f0acf 100644 --- a/ggml/src/ggml-cpu/kcpp-repackmapper.cpp +++ b/ggml/src/ggml-cpu/kcpp-repackmapper.cpp @@ -18,4 +18,5 @@ #pragma message("KoboldCpp Compiling Repack for S390X") #else #pragma message("KoboldCpp Cannot Compile Repack! Unknown Architecture!") +#error "Compilation halted due to unknown architecture." #endif \ No newline at end of file diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index d90731e1e..1e18324c2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -266,6 +266,14 @@ static bool fp16_mma_hardware_available(const int cc) { GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc); } +static bool bf16_mma_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; +} + +static bool fp32_mma_hardware_available(const int cc) { + return GGML_CUDA_CC_IS_CDNA(cc); +} + // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. static bool new_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 495ae8e16..e41064bb1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1944,16 +1944,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - bool any_gpus_with_slow_fp16 = false; - bool any_gpus_without_fp16_mma = false; + bool any_gpus_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1964,16 +1962,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); - any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } // debug helpers @@ -1984,7 +1982,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index d8c385e23..e14c93516 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -2,25 +2,26 @@ #include "common.cuh" #include "mmv.cuh" -template +template static __global__ void mul_mat_vec( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, - const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, - const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { - const int64_t row = blockIdx.x; - const int64_t channel_dst = blockIdx.y; - const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int64_t sample_dst = blockIdx.z; - const int64_t sample_x = sample_dst / sample_ratio; - const int64_t sample_y = sample_dst; - const int tid = threadIdx.x; + const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const int row = blockIdx.x; + const int channel_dst = blockIdx.y; + const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; - y += sample_y *stride_sample_y + channel_y *stride_channel_y; - dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; + dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; const float2 * y2 = (const float2 *) y; @@ -34,81 +35,108 @@ static __global__ void mul_mat_vec( __syncthreads(); } - float sumf = 0.0f; + float sumf[ncols_dst] = {0.0f}; if constexpr (std::is_same::value) { const float2 * x2 = (const float2 *) x; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = x2[col2]; - const float2 tmpy = y2[col2]; - sumf += tmpx.x*tmpy.x; - sumf += tmpx.y*tmpy.y; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += tmpx.x*tmpy.x; + sumf[j] += tmpx.y*tmpy.y; + } } } else if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; if (std::is_same::value) { - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); - const float2 tmpy = y2[col2]; - sumf += tmpx.x * tmpy.x; - sumf += tmpx.y * tmpy.y; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += tmpx.x * tmpy.x; + sumf[j] += tmpx.y * tmpy.y; + } } } else { #ifdef FP16_AVAILABLE - half2 sumh2 = make_half2(0.0f, 0.0f); + half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const float2 tmp = y2[col2]; - sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const half2 tmpx = x2[col2]; + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + } } - sumf = __low2float(sumh2) + __high2float(sumh2); +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } } else if constexpr (std::is_same::value) { const int * x2 = (const int *) x; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { - const int tmpx = x2[col2]; - const float2 tmpy = y2[col2]; - sumf += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; - sumf += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const int tmpx = x2[col2]; +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumf[j] += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf[j] += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + } } } else { static_assert(std::is_same::value, "unsupported type"); } - sumf = warp_reduce_sum(sumf); +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = warp_reduce_sum(sumf[j]); - if (block_size > warp_size) { - buf_iw[tid/warp_size] = sumf; - __syncthreads(); - if (tid >= warp_size) { - return; + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf[j]; + __syncthreads(); + if (tid < warp_size) { + sumf[j] = buf_iw[tid]; + sumf[j] = warp_reduce_sum(sumf[j]); + } + if (j < ncols_dst) { + __syncthreads(); + } } - sumf = buf_iw[tid]; - sumf = warp_reduce_sum(sumf); } - if (tid != 0) { + if (tid >= ncols_dst) { return; } - dst[row] = sumf; + dst[tid*stride_col_dst + row] = sumf[tid]; } -template +template static void launch_mul_mat_vec_cuda( const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t ncols, const int64_t nrows, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); const int64_t channel_ratio = nchannels_dst / nchannels_x; @@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda( const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { - mul_mat_vec<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, - stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + mul_mat_vec<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); @@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda( } } +template +static void mul_mat_vec_cuda_switch_ncols_dst( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 2: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 3: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 4: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 5: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 6: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 7: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + case 8: + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + template static void mul_mat_vec_cuda( const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same::value) { if (prec == GGML_PREC_DEFAULT) { - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + mul_mat_vec_cuda_switch_ncols_dst + (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } - launch_mul_mat_vec_cuda - (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + mul_mat_vec_cuda_switch_ncols_dst + (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } @@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(ncols_dst == 1); + GGML_ASSERT(!ids || ncols_dst == 1); switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec( GGML_ASSERT(dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - GGML_ASSERT(src1_ncols == 1); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; // ggml_cuda_op provides single, contiguous matrices const int64_t stride_row = ne00; + const int64_t stride_col_y = ne10; + const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer const int64_t nchannels_x = 1; const int64_t nchannels_y = 1; const int64_t nchannels_dst = 1; @@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec( switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; @@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec( GGML_UNUSED(src1_ncols); GGML_UNUSED(src1_padded_row_size); } + +bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { + if (src0_ne[0] % 2 != 0) { + return false; + } + switch (type) { + case GGML_TYPE_F32: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return ne11 <= 8; + } + if (cc >= GGML_CUDA_CC_TURING) { + return ne11 <= 4; + } + return ne11 <= 3; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (fp32_mma_hardware_available(cc)) { + return ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + case GGML_TYPE_F16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (fp16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (fp16_mma_hardware_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + return ne11 <= 5; + } + return ne11 <= 2; + } + return ne11 <= 8; + } + return ne11 <= 8; + case GGML_TYPE_BF16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (bf16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (bf16_mma_hardware_available(cc)) { + return ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh index 756e7e1cc..1330bcb6a 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -1,8 +1,5 @@ #include "common.cuh" -// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available -#define MMV_MAX_ROWS 512 - void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_op_mul_mat_vec( @@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 3b08f6134..3f541b0c0 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -167,81 +167,81 @@ class SpecialVocab: tokenizer_config['bos_token'] = special_bos = special_cls if not special_eos and special_sep and tokenizer_config: tokenizer_config['eos_token'] = special_eos = special_sep - post_processor = tokenizer.get('post_processor', {}) - for processor in post_processor.get('processors', [post_processor]): - if processor.get('type') == 'RobertaProcessing': - self.add_special_token['bos'] = True - self.add_special_token['eos'] = True - self.add_special_token['sep'] = True - if not special_cls and tokenizer_config: - special_cls = processor.get('cls', [special_bos])[0] - tokenizer_config['cls_token'] = special_cls - if not special_sep and tokenizer_config: - special_sep = processor.get('sep', [special_eos])[0] - tokenizer_config['sep_token'] = special_sep - continue - # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added - # Only works with simple templates, **will** get it wrong on unusual sequences - if processor.get('type') == 'TemplateProcessing': - tmpl_single = processor.get('single', []) - tmpl_pair = processor.get('pair', []) - special_first = None - special_last = None - if len(tmpl_single) > 1: - if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'): - if not tokenizer_config: - special_bos = special_first - self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False - if special_first not in (special_bos, special_cls): - logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing') - if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): - if not tokenizer_config: - special_eos = special_last - elif special_last != special_eos: - if 'eot' not in self.special_token_types: - self.special_token_types = tuple(self.special_token_types) + ('eot', ) - tokenizer_config['eot_token'] = special_eos - elif 'eom' not in self.special_token_types: - self.special_token_types = tuple(self.special_token_types) + ('eom', ) - tokenizer_config['eom_token'] = special_eos - else: - logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') - tokenizer_config['eos_token'] = special_eos = special_last - self.add_special_token['eos'] = True if special_last == special_eos else False - if special_last != special_eos: - logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') - if tmpl_pair: - seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 - seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None - if (special_first and seq_start == 0) or (special_last and seq_stop is None): - logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') - if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: - tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') - tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id') - if tmpl_a != 'A' or tmpl_b != 'B': - logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing') - # A [sep] [eos] B - if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]): - add_sep = False - if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'): - if special_entry in (special_sep, special_eos) and not special_last: - add_sep = True - if special_entry not in (special_sep, special_eos): - logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing') - else: - logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing') - if len(tmpl_pair) == 2: - if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'): - if special_entry in (special_sep, special_eos): + if post_processor := tokenizer.get('post_processor'): + for processor in post_processor.get('processors', [post_processor]): + if processor.get('type') == 'RobertaProcessing': + self.add_special_token['bos'] = True + self.add_special_token['eos'] = True + self.add_special_token['sep'] = True + if not special_cls and tokenizer_config: + special_cls = processor.get('cls', [special_bos])[0] + tokenizer_config['cls_token'] = special_cls + if not special_sep and tokenizer_config: + special_sep = processor.get('sep', [special_eos])[0] + tokenizer_config['sep_token'] = special_sep + continue + # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added + # Only works with simple templates, **will** get it wrong on unusual sequences + if processor.get('type') == 'TemplateProcessing': + tmpl_single = processor.get('single', []) + tmpl_pair = processor.get('pair', []) + special_first = None + special_last = None + if len(tmpl_single) > 1: + if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_bos = special_first + self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False + if special_first not in (special_bos, special_cls): + logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing') + if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'): + if not tokenizer_config: + special_eos = special_last + elif special_last != special_eos: + if 'eot' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eot', ) + tokenizer_config['eot_token'] = special_eos + elif 'eom' not in self.special_token_types: + self.special_token_types = tuple(self.special_token_types) + ('eom', ) + tokenizer_config['eom_token'] = special_eos + else: + logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!') + tokenizer_config['eos_token'] = special_eos = special_last + self.add_special_token['eos'] = True if special_last == special_eos else False + if special_last != special_eos: + logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing') + if tmpl_pair: + seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0 + seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None + if (special_first and seq_start == 0) or (special_last and seq_stop is None): + logger.warning('TemplateProcessing leading/trailing special tokens do not match TemplateProcessing') + if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]: + tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id') + tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id') + if tmpl_a != 'A' or tmpl_b != 'B': + logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing') + # A [sep] [eos] B + if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]): + add_sep = False + if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos) and not special_last: add_sep = True if special_entry not in (special_sep, special_eos): - logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing') + logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing') else: - logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing') - self.add_special_token['sep'] = add_sep - if add_sep and not special_sep and tokenizer_config: - tokenizer_config['sep_token'] = special_eos - continue + logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing') + if len(tmpl_pair) == 2: + if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'): + if special_entry in (special_sep, special_eos): + add_sep = True + if special_entry not in (special_sep, special_eos): + logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing') + else: + logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing') + self.add_special_token['sep'] = add_sep + if add_sep and not special_sep and tokenizer_config: + tokenizer_config['sep_token'] = special_eos + continue if not tokenizer_config: return True chat_template_alt = None diff --git a/include/llama.h b/include/llama.h index 00aa6a05f..5fb5d2696 100644 --- a/include/llama.h +++ b/include/llama.h @@ -393,6 +393,7 @@ extern "C" { void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types + void * prune_layers; // pointer to vector containing layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -946,12 +947,14 @@ extern "C" { // Requires the context to have a memory. // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. - // Upon non-zero return values, the memory state is restored to the state before this call + // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context + // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + // Upon other return values, the memory state is restored to the state before this call // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // 2 - aborted + // 2 - aborted (processed ubatches will remain in the context's memory) // -1 - invalid input batch - // < -1 - error + // < -1 - fatal error (processed ubatches will remain in the context's memory) LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch); diff --git a/koboldcpp.py b/koboldcpp.py index de5d45940..7f5605f67 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -61,7 +61,7 @@ logit_bias_max = 512 dry_seq_break_max = 128 # global vars -KcppVersion = "1.94.1" +KcppVersion = "1.94.2" showdebug = True kcpp_instance = None #global running instance global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""} diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index b3c996e18..5676dedcb 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -244,22 +244,34 @@ bool llama_batch_allocr::init( continue; } - if (memory) { + const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1; + + if (p0 >= 0) { + bool ok = true; + if (batch.token) { - if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) { - LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); - return false; + if (seq_pos_min(s) != p0 + 1) { + ok = false; } } else { assert(batch.embd); // for embeddings (typically used as vision input), we allow them to have repeating positions // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762 - if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { - LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); - return false; + if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) { + ok = false; } } + if (!ok) { + LLAMA_LOG_ERROR( + "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" + " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" + " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" + " it is required that the sequence positions remain consecutive: Y = X + 1\n", + __func__, s, s, p0, s, seq_pos_min(s)); + + return false; + } } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0839cad3e..5d317f4ee 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -528,12 +528,17 @@ int32_t llm_chat_apply_template( } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token - for (auto message : chat) { - std::string role(message->role); - if (role == "user") { - ss << "User: " << message->content << "\n\nAssistant:"; - } else { - ss << message->content << "\n\n"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "system") { + ss << "System: " << trim(chat[i]->content) << "\n\n"; + } else if (role == "user") { + ss << "User: " << trim(chat[i]->content) << "\n\n"; + if (i == chat.size() - 1) { + ss << "Assistant:"; + } + } else if (role == "assistant") { + ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 463abb77e..5f8be5a08 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) { pos_min[s] = std::numeric_limits::max(); } - // TODO: fix sequence indexing for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { const auto & seq_id = ubatch.seq_id[i][0]; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index 349e9032e..c95d63594 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -7,6 +7,7 @@ #include #include #include +#include // meta information about KV cells that can be part of multiple sequences at the same time // TODO: add unit tests @@ -164,7 +165,7 @@ public: assert(seq_id >= 0); seq[i].reset(seq_id); - seq_pos[seq_id].erase(pos[i]); + seq_pos_dec(seq_id, pos[i]); if (seq[i].none()) { pos[i] = -1; @@ -187,7 +188,7 @@ public: seq[i].reset(); seq[i].set(seq_id); - seq_pos[seq_id].insert(pos[i]); + seq_pos_inc(seq_id, pos[i]); return false; } @@ -232,7 +233,7 @@ public: assert(!seq[i].test(seq_id)); seq[i].set(seq_id); - seq_pos[seq_id].insert(pos[i]); + seq_pos_inc(seq_id, pos[i]); } // return the sequence id of this cell @@ -259,7 +260,9 @@ public: return -1; } - return *seq_pos[seq_id].begin(); + assert(seq_pos[seq_id].begin()->second > 0); + + return seq_pos[seq_id].begin()->first; } // the maximum position of sequence seq_id currently present in any of the cells @@ -272,7 +275,9 @@ public: return -1; } - return *seq_pos[seq_id].rbegin(); + assert(seq_pos[seq_id].rbegin()->second > 0); + + return seq_pos[seq_id].rbegin()->first; } // note: call only if the cell is not empty @@ -389,17 +394,36 @@ private: // the bitset seq[i] tells us which sequences are currently occupying the i-th cell std::vector seq; - // the set seq_pos[s] tells us which positions are currently present for sequence s + // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s + // if the position p is not present, seq_pos[s][p] is not set // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache - std::set seq_pos[LLAMA_MAX_SEQ]; + // + // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq: + // - during performing a cache reuse via (rm + add) + // - some vision models have input embeddings with repeating positions + // + std::map seq_pos[LLAMA_MAX_SEQ]; // helper functions for updating `seq_pos`, once cell at a time: + void seq_pos_dec(llama_seq_id s, llama_pos p) { + auto it = seq_pos[s].find(p); + assert(it != seq_pos[s].end()); + + if (--it->second == 0) { + seq_pos[s].erase(it); + } + } + + void seq_pos_inc(llama_seq_id s, llama_pos p) { + seq_pos[s][p]++; + } + // remove cell i void seq_pos_rm(uint32_t i) { for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq[i].test(s)) { - seq_pos[s].erase(pos[i]); + seq_pos_dec(s, pos[i]); } } } @@ -408,7 +432,7 @@ private: void seq_pos_add(uint32_t i) { for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq[i].test(s)) { - seq_pos[s].insert(pos[i]); + seq_pos_inc(s, pos[i]); } } } diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 5a92e6bd0..56f1042ef 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1,5 +1,4 @@ #include "llama-quant.h" - #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" @@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) { } } +static std::string remap_layer(const std::string & orig_name, const std::vector & prune, std::map & mapped, int & next_id) { + if (prune.empty()) { + return orig_name; + } + + static const std::regex pattern(R"(blk\.(\d+)\.)"); + if (std::smatch match; std::regex_search(orig_name, match, pattern)) { + const int blk = std::stoi(match[1]); + std::string new_name = orig_name; + + if (mapped.count(blk)) { + // Already mapped, do nothing + } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) { + mapped[blk] = ""; + } else if (blk < prune.front()) { + mapped[blk] = std::to_string(blk); + next_id = blk + 1; + } else { + mapped[blk] = std::to_string(next_id); + ++next_id; + } + + return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]); + } + + return orig_name; +} + +static std::string remap_imatrix (const std::string & orig_name, const std::map & mapped) { + if (mapped.empty()) { + return orig_name; + } + + static const std::regex pattern(R"(blk\.(\d+)\.)"); + if (std::smatch match; std::regex_search(orig_name, match, pattern)) { + const std::string blk(match[1]); + std::string new_name = orig_name; + + for (const auto & p : mapped) { + if (p.second == blk) { + LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first); + return new_name.replace(match.position(1), match.length(1), std::to_string(p.first)); + } + } + GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str()); + } + + return orig_name; +} + struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -571,6 +620,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const size_t align = GGUF_DEFAULT_ALIGNMENT; gguf_context_ptr ctx_out { gguf_init_empty() }; + std::vector prune_list = {}; + if (params->prune_layers) { + prune_list = *static_cast *>(params->prune_layers); + } + // copy the KV pairs from the input file gguf_set_kv (ctx_out.get(), ml.meta.get()); gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV @@ -600,12 +654,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + std::map mapped; + int blk_id = 0; + int pruned_attention_w = 0; + // make a list of weights std::vector tensors; tensors.reserve(ml.weights_map.size()); for (const auto & it : ml.weights_map) { + const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id)); + if (remapped_name.empty()) { + if (it.first.find("attn_v.weight") != std::string::npos || + it.first.find("attn_qkv.weight") != std::string::npos || + it.first.find("attn_kv_b.weight") != std::string::npos) { + pruned_attention_w++; + } + LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); + continue; + } else if (remapped_name != it.first) { + ggml_set_name(it.second.tensor, remapped_name.c_str()); + LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); + } tensors.push_back(&it.second); } + if (!prune_list.empty()) { + gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id); + } // keep_split requires that the weights are sorted by split index if (params->keep_split) { @@ -643,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (llama_model_has_encoder(&model)) { n_attn_layer *= 3; } - GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; @@ -684,7 +758,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (size_t i = 0; i < ctx_outs.size(); ++i) { gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); - gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size()); } } @@ -835,7 +909,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix = nullptr; if (imatrix_data) { - auto it = imatrix_data->find(tensor->name); + auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); if (it == imatrix_data->end()) { LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { @@ -950,6 +1024,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, + /*.prune_layers =*/ nullptr }; return result; diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 64b157b90..54fb9c8cd 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -293,6 +293,7 @@ int main(int argc, char ** argv) { if (!params.system_prompt.empty() || !params.prompt.empty()) { common_chat_templates_inputs inputs; + inputs.use_jinja = g_params->use_jinja; inputs.messages = chat_msgs; inputs.add_generation_prompt = !params.prompt.empty(); @@ -917,10 +918,19 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); + if (params.verbose_prompt) { + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size() - original_size); + } + for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; + const std::string token_str = common_token_to_piece(ctx, token); output_tokens.push_back(token); - output_ss << common_token_to_piece(ctx, token); + output_ss << token_str; + + if (params.verbose_prompt) { + LOG_INF("%6d -> '%s'\n", token, token_str.c_str()); + } } // reset assistant message diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index f8d1ae749..efbf27ae6 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -108,13 +108,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp return false; } -// usage: -// ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads] -// [[noreturn]] static void usage(const char * executable) { - printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable); - printf(" [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); + printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n"); + printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); @@ -125,6 +123,8 @@ static void usage(const char * executable) { printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); + printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n"); + printf(" Advanced option to remove all tensors from the given layers\n"); printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); @@ -287,6 +287,32 @@ static bool parse_tensor_type(const char * data, std::vector & prune_layers) { + if (!data) { + printf("\n%s: no layer pruning ids provided\n\n", __func__); + return false; + } + + const auto block_ids = string_split(data, ','); + for (const auto & block_id : block_ids) { + int id; + try { + id = std::stoi(block_id); + } catch (...) { + id = -1; + } + if (id < 0) { + printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str()); + return false; + } + prune_layers.emplace_back(id); + } + + sort(prune_layers.begin(), prune_layers.end()); + prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end()); + return true; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -299,6 +325,7 @@ int main(int argc, char ** argv) { std::vector included_weights, excluded_weights; std::vector kv_overrides; std::vector tensor_types; + std::vector prune_layers; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -325,6 +352,10 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) { + if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); @@ -412,6 +443,9 @@ int main(int argc, char ** argv) { if (!tensor_types.empty()) { params.tensor_types = &tensor_types; } + if (!prune_layers.empty()) { + params.prune_layers = &prune_layers; + } llama_backend_init(); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index aa18513e3..852352383 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3418,9 +3418,12 @@ struct server_context { } if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() err = "Compute error."; } + // TODO: handle ret == 2 (abort) when we start aborting + if (!err.empty()) { SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); for (auto & slot : slots) {