Commit ace537d44e — https://github.com/LostRuins/koboldcpp.git

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	.github/workflows/release.yml
#	CMakeLists.txt
#	examples/simple-chat/simple-chat.cpp
#	src/llama-quant.cpp
#	tools/run/run.cpp
#	tools/server/README.md

17 changed files with 554 additions and 212 deletions
```diff
@@ -21,4 +21,5 @@
 #include "arch/s390/quants.c"
 #else
 #pragma message("KoboldCpp Cannot Compile Quants! Unknown Architecture!")
+#error "Compilation halted due to unknown architecture."
 #endif
```
```diff
@@ -18,4 +18,5 @@
 #pragma message("KoboldCpp Compiling Repack for S390X")
 #else
 #pragma message("KoboldCpp Cannot Compile Repack! Unknown Architecture!")
+#error "Compilation halted due to unknown architecture."
 #endif
```
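Both hunks tighten KoboldCpp's per-architecture source dispatch for the CPU quantization and repack code: an unrecognized target previously produced only a `#pragma message` diagnostic and the build limped on until missing symbols failed at link time, while the added `#error` stops compilation right at the point of misconfiguration. A minimal sketch of the pattern, with illustrative macro and path names rather than the repository's actual ones:

```cpp
// Illustrative per-architecture dispatch with a fail-fast fallback; the
// include paths and the x86/ARM branches are assumptions for this sketch.
#if defined(__x86_64__) || defined(_M_X64)
    #include "arch/x86/quants.c"
#elif defined(__aarch64__)
    #include "arch/arm/quants.c"
#elif defined(__s390x__)
    #include "arch/s390/quants.c"   // the branch visible in the hunks above
#else
    #pragma message("Cannot compile quants! Unknown architecture!")
    #error "Compilation halted due to unknown architecture."
#endif
```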
```diff
@@ -266,6 +266,14 @@ static bool fp16_mma_hardware_available(const int cc) {
         GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
+static bool bf16_mma_hardware_available(const int cc) {
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+}
+
+static bool fp32_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_CDNA(cc);
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
```
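The two new predicates mirror the existing FP16 check: BF16 matrix cores are treated as present on NVIDIA Ampere and newer, on AMD CDNA, and on RDNA3 or newer, while FP32 matrix hardware is assumed only on CDNA. They feed the `ggml_cuda_should_use_mmv` heuristic introduced further down, roughly along these lines (a sketch of the idea, not the actual call site):

```cpp
// Sketch: an arch predicate shifts the batch-size threshold below which the
// custom vector kernel is preferred over a tensor-core GEMM. The thresholds
// here are the BF16 values from ggml_cuda_should_use_mmv below.
bool prefer_vec_kernel_bf16(int cc, long batch_cols) {
    // with BF16 MMA hardware, GEMM wins except for very thin batches
    return bf16_mma_hardware_available(cc) ? batch_cols <= 3 : batch_cols <= 8;
}
```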
```diff
@@ -1944,16 +1944,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
     bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q     = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
     bool any_gpus_with_slow_fp16 = false;
-    bool any_gpus_without_fp16_mma = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1964,16 +1962,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
             const int cc = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q   = use_mul_mat_q   && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q   = use_mul_mat_q   && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
 
     // debug helpers
@@ -1984,7 +1982,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
```
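Net effect of these three hunks: `use_mul_mat_vec` no longer hard-codes `src1->ne[1] == 1`, and the `MMV_MAX_ROWS` / `any_gpus_without_fp16_mma` heuristic disappears; instead the flag is refined per device through the new `ggml_cuda_should_use_mmv()` added at the end of mmv.cu below, so thin batches of several columns can now take the vector-kernel path. A condensed paraphrase of the resulting decision flow (a sketch, not a verbatim excerpt):

```cpp
#include <cstdint>
#include <vector>

// Paraphrase of the hunks above: the dtype gate stays global, while the
// batch-size gate becomes a per-device query against the new heuristic.
bool decide_use_mmv(ggml_type type, const int64_t * src0_ne, int64_t ne11,
                    const std::vector<int> & device_ids) {
    bool use_mmv = type == GGML_TYPE_F32 || type == GGML_TYPE_F16 || type == GGML_TYPE_BF16;
    for (int id : device_ids) {
        const int cc = ggml_cuda_info().devices[id].cc;
        use_mmv = use_mmv && ggml_cuda_should_use_mmv(type, cc, src0_ne, ne11);
    }
    return use_mmv;
}
```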
```diff
@@ -2,25 +2,26 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename T, typename type_acc, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
-        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
-    const int64_t row         = blockIdx.x;
-    const int64_t channel_dst = blockIdx.y;
-    const int64_t channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
-    const int64_t channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int64_t sample_dst  = blockIdx.z;
-    const int64_t sample_x    = sample_dst / sample_ratio;
-    const int64_t sample_y    = sample_dst;
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
     const int tid = threadIdx.x;
+
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
-    y   += sample_y *stride_sample_y + channel_y *stride_channel_y;
-    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
+    x   += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+    y   += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
```
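A detail worth noting in this hunk: the kernel's indices narrow from `int64_t` to `int` to reduce register pressure, and the three pointer adjustments gain explicit `int64_t(...)` promotions so the offset products are computed in 64 bits. Without the promotion, an `int * int` product can overflow for large tensors before it ever reaches pointer arithmetic:

```cpp
#include <cstdint>

// Sketch of why the cast matters (values illustrative): the offset product
// must be widened before it can wrap, not after.
void advance(const float *& x, int sample, int stride) {
    // x += sample * stride;           // int * int: may overflow and wrap
    x += int64_t(sample) * stride;     // 64-bit product, as in the hunk above
}
```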
```diff
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf = 0.0f;
+    float sumf[ncols_dst] = {0.0f};
 
     if constexpr (std::is_same<T, float>::value) {
         const float2 * x2 = (const float2 *) x;
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const float2 tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x*tmpy.x;
-            sumf += tmpx.y*tmpy.y;
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
         }
     } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
-                const float2 tmpy = y2[col2];
-                sumf += tmpx.x * tmpy.x;
-                sumf += tmpx.y * tmpy.y;
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
             }
         } else {
 #ifdef FP16_AVAILABLE
-            half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
 
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-                const float2 tmp = y2[col2];
-                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
             }
 
-            sumf = __low2float(sumh2) + __high2float(sumh2);
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
 #else
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const int tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
         }
     } else {
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum<warp_size>(sumf);
-
-    if (block_size > warp_size) {
-        buf_iw[tid/warp_size] = sumf;
-        __syncthreads();
-        if (tid >= warp_size) {
-            return;
-        }
-        sumf = buf_iw[tid];
-        sumf = warp_reduce_sum<warp_size>(sumf);
-    }
-
-    if (tid != 0) {
-        return;
-    }
-
-    dst[row] = sumf;
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
+        }
+    }
+
+    if (tid >= ncols_dst) {
+        return;
+    }
+
+    dst[tid*stride_col_dst + row] = sumf[tid];
 }
 
-template <typename T, typename type_acc>
+template <typename T, typename type_acc, int ncols_dst>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols        % 2 == 0);
     GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x == 0);
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
```
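The kernel now produces up to `ncols_dst` dot products per row instead of one: the scalar accumulator becomes a small register array, each loaded `x` value is reused against `ncols_dst` columns of `y`, the reduction runs once per column, and thread `tid` writes output column `tid`. The reuse of `tmpx` across columns is the point, since it turns a memory-bound GEMV into a small register-tiled GEMM. The core pattern in isolation (a sketch of the technique, not the full kernel):

```cpp
// Register-tiled accumulation as used above: one x load feeds ncols_dst
// partial dot products (CUDA device-code sketch).
template <int ncols_dst, int block_size>
__device__ void accumulate(const float2 * x2, const float2 * y2, int ncols2,
                           int stride_col_y2, int tid, float (&sumf)[ncols_dst]) {
    for (int col2 = tid; col2 < ncols2; col2 += block_size) {
        const float2 tmpx = x2[col2];               // loaded once ...
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {       // ... reused per dst column
            const float2 tmpy = y2[j*stride_col_y2 + col2];
            sumf[j] += tmpx.x*tmpy.x + tmpx.y*tmpy.y;
        }
    }
}
```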
```diff
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case 32: {
-            mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 64: {
-            mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 96: {
-            mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 128: {
-            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 160: {
-            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 192: {
-            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 224: {
-            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case 256: {
-            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
```
```diff
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
 template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     if constexpr(std::is_same<T, half>::value) {
         if (prec == GGML_PREC_DEFAULT) {
-            launch_mul_mat_vec_cuda<T, half>
-                (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+            mul_mat_vec_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
-    launch_mul_mat_vec_cuda<T, float>
-        (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+    mul_mat_vec_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
          stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
```
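`mul_mat_vec_cuda_switch_ncols_dst` is a second level of the same dispatch already used for `block_size`: a runtime column count is mapped onto a template parameter so the `#pragma unroll` loops over `ncols_dst` in the kernel fully unroll and the accumulator array stays in registers. The idiom in miniature:

```cpp
#include <cassert>

// Switch-to-template dispatch: the runtime value n picks a specialization in
// which N is a compile-time constant (sketch).
template <int N>
void run_impl() {
    for (int j = 0; j < N; ++j) {
        // fully unrollable: N is known to the compiler here
    }
}

void run(int n) {
    switch (n) {
        case 1: run_impl<1>(); break;
        case 2: run_impl<2>(); break;
        case 3: run_impl<3>(); break;
        case 4: run_impl<4>(); break;
        default: assert(false && "unsupported count");
    }
}
```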
```diff
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t stride_channel_dst = ids ? s1 : s2;
     const int64_t stride_channel_y   = ids ? s11 : s12;
 
-    GGML_ASSERT(ncols_dst == 1);
+    GGML_ASSERT(!ids || ncols_dst == 1);
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03, s13, s3, prec, ctx.stream());
         } break;
```
```diff
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  =  dst->ne[0];
     const int64_t row_diff = row_high - row_low;
 
-    GGML_ASSERT(src1_ncols == 1);
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
 
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row     = ne00;
+    const int64_t stride_col_y   = ne10;
+    const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
     const int64_t nchannels_x    = 1;
     const int64_t nchannels_y    = 1;
     const int64_t nchannels_dst  = 1;
```
```diff
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
```
```diff
@@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return ne11 <= 8;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
```
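`ggml_cuda_should_use_mmv` gathers the dispatch heuristics in one place: the `src0_ne[0] % 2` check matches the kernel's paired `float2`/`half2` loads, and the `ne11` thresholds encode where the vector kernel stops beating the MMA or cuBLAS paths for each dtype and vendor. One consequence of the `src0_small` condition worth spelling out (a worked reading of the F16/Ada branch above, with illustrative shapes):

```cpp
#include <cstdint>

// For an Ada-class NVIDIA GPU: a plain 2D weight (ne[2]*ne[3] == 1) counts as
// "small" regardless of height, so only the batch width ne11 decides; a
// multi-channel tensor must also be thin (ne[1] <= 512) to qualify.
bool f16_ada_example() {
    const int64_t weight_2d[4]  = {4096, 4096, 1, 1}; // src0_small -> true
    const int64_t attn_multi[4] = {128, 1024, 32, 1}; // tall and 32 channels -> false
    return (weight_2d[1] <= 512 || weight_2d[2]*weight_2d[3] == 1)
        && !(attn_multi[1] <= 512 || attn_multi[2]*attn_multi[3] == 1);
}
```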
```diff
@@ -1,8 +1,5 @@
 #include "common.cuh"
 
-// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
-#define MMV_MAX_ROWS 512
-
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
```
```diff
@@ -167,81 +167,81 @@ class SpecialVocab:
                 tokenizer_config['bos_token'] = special_bos = special_cls
             if not special_eos and special_sep and tokenizer_config:
                 tokenizer_config['eos_token'] = special_eos = special_sep
-        post_processor = tokenizer.get('post_processor', {})
-        for processor in post_processor.get('processors', [post_processor]):
+        if post_processor := tokenizer.get('post_processor'):
+            for processor in post_processor.get('processors', [post_processor]):
                 if processor.get('type') == 'RobertaProcessing':
                     self.add_special_token['bos'] = True
                     self.add_special_token['eos'] = True
                     self.add_special_token['sep'] = True
                     if not special_cls and tokenizer_config:
                         special_cls = processor.get('cls', [special_bos])[0]
                         tokenizer_config['cls_token'] = special_cls
                     if not special_sep and tokenizer_config:
                         special_sep = processor.get('sep', [special_eos])[0]
                         tokenizer_config['sep_token'] = special_sep
                     continue
                 # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
                 # Only works with simple templates, **will** get it wrong on unusual sequences
                 if processor.get('type') == 'TemplateProcessing':
                     tmpl_single = processor.get('single', [])
                     tmpl_pair = processor.get('pair', [])
                     special_first = None
                     special_last = None
                     if len(tmpl_single) > 1:
                         if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
                             if not tokenizer_config:
                                 special_bos = special_first
                             self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
                             if special_first not in (special_bos, special_cls):
                                 logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
                         if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
                             if not tokenizer_config:
                                 special_eos = special_last
                             elif special_last != special_eos:
                                 if 'eot' not in self.special_token_types:
                                     self.special_token_types = tuple(self.special_token_types) + ('eot', )
                                     tokenizer_config['eot_token'] = special_eos
                                 elif 'eom' not in self.special_token_types:
                                     self.special_token_types = tuple(self.special_token_types) + ('eom', )
                                     tokenizer_config['eom_token'] = special_eos
                                 else:
                                     logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
                                 tokenizer_config['eos_token'] = special_eos = special_last
                             self.add_special_token['eos'] = True if special_last == special_eos else False
                             if special_last != special_eos:
                                 logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
                     if tmpl_pair:
                         seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
                         seq_stop  = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
                         if (special_first and seq_start == 0) or (special_last and seq_stop is None):
                             logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
                         if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
                             tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
                             tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
                             if tmpl_a != 'A' or tmpl_b != 'B':
                                 logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
                             # A [sep] [eos] B
                             if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
                                 add_sep = False
                                 if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
                                     if special_entry in (special_sep, special_eos) and not special_last:
-                                        add_sep = True
-                                    if special_entry not in (special_sep, special_eos):
-                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
-                                else:
-                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
-                                if len(tmpl_pair) == 2:
-                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
-                                        if special_entry in (special_sep, special_eos):
                                         add_sep = True
                                     if special_entry not in (special_sep, special_eos):
-                                        logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
                                 else:
-                                    logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
-                                self.add_special_token['sep'] = add_sep
-                                if add_sep and not special_sep and tokenizer_config:
-                                    tokenizer_config['sep_token'] = special_eos
-                                continue
+                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+                                if len(tmpl_pair) == 2:
+                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos):
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                    else:
+                                        logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+                                self.add_special_token['sep'] = add_sep
+                                if add_sep and not special_sep and tokenizer_config:
+                                    tokenizer_config['sep_token'] = special_eos
+                                continue
         if not tokenizer_config:
             return True
         chat_template_alt = None
```
```diff
@@ -393,6 +393,7 @@ extern "C" {
         void * imatrix;      // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
         void * tensor_types; // pointer to vector containing tensor types
+        void * prune_layers; // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
@@ -946,12 +947,14 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the memory state is restored to the state before this call
+    // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
+    // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+    // Upon other return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted
+    //    2 - aborted (processed ubatches will remain in the context's memory)
     //   -1 - invalid input batch
-    // < -1 - error
+    // < -1 - fatal error (processed ubatches will remain in the context's memory)
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
             struct llama_batch batch);
```
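This changes the error-handling contract of `llama_decode()`: on abort (2) or fatal error (< -1) the already-processed ubatches are no longer rolled back, so callers that care about partial progress should inspect the memory state. A hedged sketch of what handling that might look like (assumes an existing `ctx` and `batch`; not code from the commit):

```cpp
#include "llama.h"

// Sketch: after an abort or fatal error, query how far sequence 0 actually
// advanced before deciding to retry, resubmit, or clear the range.
void decode_with_recovery(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 2 || ret < -1) {
        llama_memory_t mem = llama_get_memory(ctx);
        const llama_pos p_min = llama_memory_seq_pos_min(mem, 0);
        const llama_pos p_max = llama_memory_seq_pos_max(mem, 0);
        (void) p_min;
        (void) p_max; // e.g. re-submit only the tokens after p_max
    }
}
```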
```diff
@@ -61,7 +61,7 @@ logit_bias_max = 512
 dry_seq_break_max = 128
 
 # global vars
-KcppVersion = "1.94.1"
+KcppVersion = "1.94.2"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}
```
```diff
@@ -244,22 +244,34 @@ bool llama_batch_allocr::init(
             continue;
         }
 
-        if (memory) {
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
             if (batch.token) {
-                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);
 
                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             }
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
```
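Besides hoisting `memory->seq_pos_max(s)` into `p0` and making the diagnostic actually explain the constraint, the hunk keeps the validation rule itself unchanged. Stated in isolation (names illustrative, not from the commit):

```cpp
#include <cstdint>

// Token batches must continue directly after the cache (Y = X + 1); embedding
// batches, e.g. vision inputs, may also repeat the last position (Y = X).
bool positions_consistent(int64_t X, int64_t Y, bool is_embd) {
    return is_embd ? (Y == X || Y == X + 1) : (Y == X + 1);
}
```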
```diff
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
```
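Read off the branches above, the RWKV World template now handles all three roles, trims each content string, and appends the generation prompt only when the final message is a user turn. A short conversation therefore renders like this (example content invented to illustrate the shape):

```
System: You are a helpful assistant.

User: Hello!

Assistant: Hi, how can I help?

User: What does the RWKV template look like?

Assistant:
```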
```diff
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     pos_min[s] = std::numeric_limits<llama_pos>::max();
                 }
 
-                // TODO: fix sequence indexing
                 for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                     const auto & seq_id = ubatch.seq_id[i][0];
 
```
@ -7,6 +7,7 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||||
// TODO: add unit tests
|
// TODO: add unit tests
|
||||||
|
@ -164,7 +165,7 @@ public:
|
||||||
assert(seq_id >= 0);
|
assert(seq_id >= 0);
|
||||||
|
|
||||||
seq[i].reset(seq_id);
|
seq[i].reset(seq_id);
|
||||||
seq_pos[seq_id].erase(pos[i]);
|
seq_pos_dec(seq_id, pos[i]);
|
||||||
|
|
||||||
if (seq[i].none()) {
|
if (seq[i].none()) {
|
||||||
pos[i] = -1;
|
pos[i] = -1;
|
||||||
|
@@ -187,7 +188,7 @@ public:
             seq[i].reset();

             seq[i].set(seq_id);
-            seq_pos[seq_id].insert(pos[i]);
+            seq_pos_inc(seq_id, pos[i]);

             return false;
         }
@@ -232,7 +233,7 @@ public:
         assert(!seq[i].test(seq_id));

         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
     }

     // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
             return -1;
         }

-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }

     // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
             return -1;
         }

-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }

     // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;

-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+    //
+    // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
+    //   - during performing a cache reuse via (rm + add)
+    //   - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];

     // helper functions for updating `seq_pos`, once cell at a time:

+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
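The switch from std::set to a position-to-count map is in effect manual reference counting. A standalone sketch of the invariant being protected (not the actual class, just the idea):

    #include <cassert>
    #include <map>

    int main() {
        std::map<int, int> pos_refs; // position -> number of cells currently holding it

        pos_refs[5]++; // one cell takes position 5
        pos_refs[5]++; // cache reuse (rm + add): a second cell takes the same position

        // releasing one cell decrements the count; the position must stay visible
        if (--pos_refs[5] == 0) {
            pos_refs.erase(5);
        }

        assert(pos_refs.begin()->first == 5); // min position is still 5, not lost
        return 0;
    }

With the old std::set, the first erase would have dropped position 5 entirely and corrupted the min/max queries above.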
@@ -408,7 +432,7 @@ private:
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
@@ -1,5 +1,4 @@
 #include "llama-quant.h"
-
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }

+static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
+    if (prune.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const int blk = std::stoi(match[1]);
+        std::string new_name = orig_name;
+
+        if (mapped.count(blk)) {
+            // Already mapped, do nothing
+        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
+            mapped[blk] = "";
+        } else if (blk < prune.front()) {
+            mapped[blk] = std::to_string(blk);
+            next_id = blk + 1;
+        } else {
+            mapped[blk] = std::to_string(next_id);
+            ++next_id;
+        }
+
+        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
+    }
+
+    return orig_name;
+}
+
+static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
+    if (mapped.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const std::string blk(match[1]);
+        std::string new_name = orig_name;
+
+        for (const auto & p : mapped) {
+            if (p.second == blk) {
+                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
+                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
+            }
+        }
+        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
+    }
+
+    return orig_name;
+}
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
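To make the remapping concrete: pruning layer 1 from a hypothetical 4-layer model yields mapped = {0:"0", 1:"", 2:"1", 3:"2"}, i.e.

    blk.0.* -> blk.0.*   (below the first pruned layer, kept as-is)
    blk.1.* -> (dropped)
    blk.2.* -> blk.1.*
    blk.3.* -> blk.2.*

next_id ends at 3, the new block count, and remap_imatrix inverts the mapping so that renamed tensors still find their original entries in the importance matrix.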
@@ -571,6 +620,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };

+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -600,12 +654,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
         }
     }

+    std::map<int, std::string> mapped;
+    int blk_id = 0;
+    int pruned_attention_w = 0;
+
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
+        if (remapped_name.empty()) {
+            if (it.first.find("attn_v.weight") != std::string::npos ||
+                it.first.find("attn_qkv.weight") != std::string::npos ||
+                it.first.find("attn_kv_b.weight") != std::string::npos) {
+                pruned_attention_w++;
+            }
+            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
+            continue;
+        } else if (remapped_name != it.first) {
+            ggml_set_name(it.second.tensor, remapped_name.c_str());
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+        }
         tensors.push_back(&it.second);
     }
+    if (!prune_list.empty()) {
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
+    }

     // keep_split requires that the weights are sorted by split index
     if (params->keep_split) {
@@ -643,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        GGML_ASSERT_CONTINUE((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }

     size_t total_size_org = 0;
@@ -684,7 +758,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
         }
     }

@@ -835,7 +909,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:

         const float * imatrix = nullptr;
         if (imatrix_data) {
-            auto it = imatrix_data->find(tensor->name);
+            auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
             if (it == imatrix_data->end()) {
                 LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
             } else {
@@ -950,6 +1024,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.imatrix      =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
         /*.tensor_type  =*/ nullptr,
+        /*.prune_layers =*/ nullptr
     };

     return result;
@@ -293,6 +293,7 @@ int main(int argc, char ** argv) {

     if (!params.system_prompt.empty() || !params.prompt.empty()) {
         common_chat_templates_inputs inputs;
+        inputs.use_jinja = g_params->use_jinja;
         inputs.messages = chat_msgs;
         inputs.add_generation_prompt = !params.prompt.empty();

@@ -917,10 +918,19 @@ int main(int argc, char ** argv) {
             embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
             embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());

+            if (params.verbose_prompt) {
+                LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size() - original_size);
+            }
+
             for (size_t i = original_size; i < embd_inp.size(); ++i) {
                 const llama_token token = embd_inp[i];
+                const std::string token_str = common_token_to_piece(ctx, token);
                 output_tokens.push_back(token);
-                output_ss << common_token_to_piece(ctx, token);
+                output_ss << token_str;
+
+                if (params.verbose_prompt) {
+                    LOG_INF("%6d -> '%s'\n", token, token_str.c_str());
+                }
             }

             // reset assistant message
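With --verbose-prompt enabled, each appended token is now echoed as it is added; the output looks roughly like this (token ids and pieces are illustrative):

    main: number of tokens in prompt = 2
      9906 -> ' Hello'
        11 -> ','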
@@ -108,13 +108,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     return false;
 }

-// usage:
-//  ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
-//
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
-    printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
+    printf("       [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
+    printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -125,6 +123,8 @@ static void usage(const char * executable) {
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
     printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
+    printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
+    printf("      Advanced option to remove all tensors from the given layers\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -287,6 +287,32 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantizatio
     return true;
 }

+static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
+    if (!data) {
+        printf("\n%s: no layer pruning ids provided\n\n", __func__);
+        return false;
+    }
+
+    const auto block_ids = string_split<std::string>(data, ',');
+    for (const auto & block_id : block_ids) {
+        int id;
+        try {
+            id = std::stoi(block_id);
+        } catch (...) {
+            id = -1;
+        }
+        if (id < 0) {
+            printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
+            return false;
+        }
+        prune_layers.emplace_back(id);
+    }
+
+    sort(prune_layers.begin(), prune_layers.end());
+    prune_layers.erase(std::unique(prune_layers.begin(), prune_layers.end()), prune_layers.end());
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
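A hypothetical invocation of the new flag (file names and quantization type are placeholders):

    ./llama-quantize --prune-layers 20,21,22 model-f32.gguf model-pruned.gguf q4_k_m

The id list may be unsorted and contain duplicates, since it is sorted and de-duplicated before use; any non-numeric or negative id triggers the usage message instead.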
@@ -299,6 +325,7 @@ int main(int argc, char ** argv) {
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<tensor_quantization> tensor_types;
+    std::vector<int> prune_layers;

     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -325,6 +352,10 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
+            if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -412,6 +443,9 @@ int main(int argc, char ** argv) {
     if (!tensor_types.empty()) {
         params.tensor_types = &tensor_types;
     }
+    if (!prune_layers.empty()) {
+        params.prune_layers = &prune_layers;
+    }

     llama_backend_init();

@@ -3418,9 +3418,12 @@ struct server_context {
             }

             if (ret < -1) {
+                // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
                 err = "Compute error.";
             }

+            // TODO: handle ret == 2 (abort) when we start aborting
+
             if (!err.empty()) {
                 SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
                 for (auto & slot : slots) {