diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h index 1cb4bd7a0..faaba8f24 100644 --- a/ggml/include/ggml-virtgpu.h +++ b/ggml/include/ggml-virtgpu.h @@ -7,8 +7,6 @@ extern "C" { #endif -#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend" - GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); #ifdef __cplusplus diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 6d1a2f794..445a82a61 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -269,9 +269,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); } -static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { - return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), - _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); +static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) { + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); } #endif #elif defined(__SSSE3__) @@ -783,6 +783,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); @@ -796,10 +797,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), - _mm256_cvtepi32_ps(p_2), accum2); + const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e)); + const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e)); + accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2); } sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); @@ -831,7 +832,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif for (; ib < nb; ++ib) { - const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e); int sumi1 = 0; int sumi2 = 0; for (int j = 0; j < QK_MXFP4/2; ++j) { @@ -3818,4 +3819,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index d11020021..55f46d3c5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -76,6 +76,9 @@ // precomputed f32 table for f16 (256 KB) (simd-mappings.h) float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h) +float ggml_table_f32_e8m0_half[1 << 8]; + #if defined(__ARM_ARCH) struct ggml_arm_arch_features_type { int sve_cnt; @@ -4530,6 +4533,11 @@ void ggml_cpu_init(void) { ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f)); } + // initialize E8M0 half table (256 entries) + for (int i = 0; i < (1 << 8); ++i) { + ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i); + } + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index e367f110b..630e50654 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -116,6 +116,17 @@ extern "C" { // defined in ggml-cpu.c, initialized in ggml_cpu_init() extern float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) +// defined in ggml-cpu.c, initialized in ggml_cpu_init() +extern float ggml_table_f32_e8m0_half[1 << 8]; + +// Use lookup table for E8M0 on x86 (faster than bit manipulation) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)] +#else +#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x) +#endif + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 44a3348e2..a6df585ae 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2287,13 +2287,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - if (ne2 == 1) { + static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + if (ne2 <= 4) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } } else { - ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } } - return; } if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 607d33973..fa23c9800 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3698,13 +3698,20 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } - template -static __global__ void mul_mat_q_stream_k_fixup( - const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, - const int ncols_max) { +static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, + const int32_t * expert_bounds, + float * __restrict__ dst, + const float * __restrict__ tmp_last_tile, + const int ncols_x, + const int nrows_x, + const int ncols_dst, + const size_t stride_col_dst, + const int nchannels_y, + const size_t stride_channel_dst, + const int nsamples_y, + const size_t stride_sample_dst, + const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int ITER_K = get_iter_k(type); diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 32948e4d7..d91472024 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -4,26 +4,48 @@ #include "mmvf.cuh" #include "convert.cuh" -template +template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride) { const int row = blockIdx.x; + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; + const int tid = threadIdx.x; + + int token_idx; + int channel_x; + int channel_y; + int sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + token_idx = ids ? blockIdx.z : 0; + channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio); + channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + sample_dst = ids ? 0 : blockIdx.z; + } + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; - const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y2*2; + dst += token_idx*stride_col_dst; + } bool use_gate = false; bool use_bias = false; @@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f( if (use_gate) { gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } + + const int channel_bias = ids ? channel_x : channel_dst; + if constexpr (has_fusion) { - const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f( } } -template +template static void mul_mat_vec_f_switch_fusion( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, + const int64_t ncols, const uint3 nchannels_y, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } -template +template void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, @@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); @@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 64: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 96: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 128: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 160: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 192: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 224: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 256: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; default: { GGML_ABORT("fatal error"); @@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t ids_stride, cudaStream_t stream) { + + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + if (has_ids) { + // Single-token MUL_MAT_ID path + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 2: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 3: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 4: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 5: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 6: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 7: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 8: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { + const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); } void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, @@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE); GGML_ASSERT(ne13 == ne3); GGML_ASSERT( nb00 == ts_src0); @@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f( const float * src0_d = (const float *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index a09fbdc72..a50f7c021 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels. + void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d671551c1..ce25ccf42 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -// tell the compiler to use as many registers as it wants, see nwarps definition below -template +template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; + + uint32_t token_idx = 0; + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + } + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q( active_glu = fusion.glu_op; } - const uint32_t channel_bias = ids ? channel_x : channel_dst; float x_biases[ncols_dst] = { 0.0f }; float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; // 1. Hide latency by prefetching bias and gate here @@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y; + } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + if constexpr (is_multi_token_id) { + dst += token_idx*stride_col_dst; + } + // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q( } static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id) { const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } template @@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); @@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst( const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } - GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; @@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 2: { constexpr int c_ncols_dst = 2; @@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 3: { constexpr int c_ncols_dst = 3; @@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 4: { constexpr int c_ncols_dst = 4; @@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 5: { constexpr int c_ncols_dst = 5; @@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 6: { constexpr int c_ncols_dst = 6; @@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 7: { constexpr int c_ncols_dst = 7; @@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 8: { constexpr int c_ncols_dst = 8; @@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; default: GGML_ABORT("fatal error"); @@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mul_mat_vec_q_switch_type( src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, ids_stride, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 377b0d3eb..4cd3d93d8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -534,6 +534,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int nsg = 8; + const int n = op->src[1]->ne[1]; + const int k = op->src[1]->ne[0]; + + snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0); + ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1); + ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16); + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index afb091e72..d89843271 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -121,6 +121,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 150fd5a4e..238fff135 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1158,6 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 59d88b01a..640ade8f8 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -78,13 +78,14 @@ #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 -#define FC_COUNT_EQUAL 1000 +#define FC_SOLVE_TRI 1000 +#define FC_COUNT_EQUAL 1100 // op-specific constants -#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NQPSG 8 #define OP_FLASH_ATTN_EXT_NCPSG 64 -#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs @@ -733,6 +734,33 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int32_t ne00t; int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7f4cfbba2..753fcec31 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -341,6 +341,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_MUL_MAT: { n_fuse = ggml_metal_op_mul_mat(ctx, idx); @@ -1557,6 +1561,63 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -2295,7 +2356,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { // return res; //} - const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; @@ -2411,7 +2472,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); @@ -2578,9 +2639,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! - const int nkpsg = 1*ncpsg; + const int nhptg = 1; // heads per threadgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -2632,6 +2693,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); } + // note: for simplicity assume the K is larger or equal than V + GGML_ASSERT(ne10 >= ne20); + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2639,28 +2703,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes - if (smem > props_dev->max_theadgroup_memory_size/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; // workgroups // each workgroup handles nsg*nkpsg cache values @@ -2673,7 +2718,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } else { nwg = 32; nsg = 1; - while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { nsg *= 2; } } @@ -2739,7 +2784,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); } else { // sanity checks assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); @@ -2752,7 +2797,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); // sync the 2 kernels ggml_metal_op_concurrency_reset(ctx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 10686a334..2e4c7d3fa 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -60,6 +60,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 17e358d1a..c09a54e66 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2737,6 +2737,83 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; +constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; +constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; + +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + ushort3 tgpig[[threadgroup_position_in_grid]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const short NSG = FC_solve_tri_nsg; + const short N = FC_solve_tri_n; + const short K = FC_solve_tri_k; + const short NP = PAD2(N, NW); + + const int32_t ne02 = args.ne02; + const int32_t ne03 = args.ne03; + + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x*NSG + sgitg; + + threadgroup float * sh0 = (threadgroup float *) shmem; + + device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N; + device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01; + device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01; + + for (short rr = 0; rr < N; rr += NSG) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + threadgroup float * sh0_cur = sh0 + sgitg*NP; + + for (short t = 0; t*NW < N; ++t) { + const short idx = t*NW + tiisg; + sh0_cur[idx] = src0_ptr[idx]; + } + + src0_ptr += NSG*N; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (i01 >= args.ne10) { + continue; + } + + for (short ir = 0; ir < NSG && rr + ir < N; ++ir) { + const short r = rr + ir; + + threadgroup float * sh0_cur = sh0 + ir*NP; + + float sum = 0.0f; + + for (short t = 0; t*NW < r; ++t) { + const short idx = t*NW + tiisg; + sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + const float diag = sh0_cur[r]; + + dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag; + } + } + } +} + kernel void kernel_argmax_f32( constant ggml_metal_kargs_argmax & args, device const char * src0, @@ -5931,7 +6008,7 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, @@ -6141,11 +6218,10 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE, // head elements per thread - short Q, // queries per threadgroup - short C, // cache items per threadgroup - short NSG> // number of simd groups -void kernel_flash_attn_ext_vec_impl( + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, @@ -6162,6 +6238,7 @@ void kernel_flash_attn_ext_vec_impl( static_assert(DV % 32 == 0, "DV must be divisible by 32"); #define NWG (FC_flash_attn_ext_vec_nwg) +#define NSG (FC_flash_attn_ext_vec_nsg) #define NS10 (FC_flash_attn_ext_vec_ns10) #define NS20 (FC_flash_attn_ext_vec_ns20) @@ -6190,12 +6267,12 @@ void kernel_flash_attn_ext_vec_impl( const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results // store the result for all queries in shared memory (the O matrix from the paper) so4 += tiisg; @@ -6213,11 +6290,13 @@ void kernel_flash_attn_ext_vec_impl( // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < PK4; i += NW) { - if (iq1 < args.ne01 && i < DK4) { - sq4[i] = (q4_t) q4[i]; - } else { - sq4[i] = (q4_t) 0.0f; + if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (i < DK4) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } } } @@ -6295,7 +6374,7 @@ void kernel_flash_attn_ext_vec_impl( } // skip -INF blocks - if (simd_max(sm[tiisg]) == -INFINITY) { + if (simd_max(sm[tiisg]) <= -MAXHALF) { continue; } @@ -6569,57 +6648,11 @@ void kernel_flash_attn_ext_vec_impl( } #undef NWG +#undef NSG #undef NS10 #undef NS20 } -template< - typename q4_t, // query types in shared memory - typename k4_t, // key types in shared memory - typename v4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types - typename s4_t, - typename o4_t, // attention accumulation types - typename kd4_t, // key type in device memory - short nl_k, - void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // value type in device memory - short nl_v, - void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), - short DK, // K head size - short DV, // V head size - short NE = 4, // head elements per thread - short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup - short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext_vec & args, - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device const char * sinks, - device const char * pad, - device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { -#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg - switch (FC_flash_attn_ext_vec_nsg) { - // note: disabled cases to reduce library load time - case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - } -#undef FWD_TMPL -#undef FWD_ARGS -} - // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 938ce5b83..e664269a7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -270,6 +270,7 @@ enum vk_device_architecture { AMD_RDNA3, INTEL_XE2, NVIDIA_PRE_TURING, + NVIDIA_TURING, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -352,18 +353,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& const std::vector ext_props = device.enumerateDeviceExtensionProperties(); bool cooperative_matrix = false; + bool sm_builtins = false; // Detect "pre-turing" based on lack of coopmat support. for (const auto& properties : ext_props) { if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { cooperative_matrix = true; - break; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; } } if (!cooperative_matrix) { return vk_device_architecture::NVIDIA_PRE_TURING; } + + if (sm_builtins) { + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + + props2.pNext = &sm_props; + + device.getProperties2(&props2); + + // Turing has 32, following architectures have 48 + if (sm_props.shaderWarpsPerSM == 32) { + return vk_device_architecture::NVIDIA_TURING; + } + } } return vk_device_architecture::OTHER; } @@ -8498,6 +8515,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + if (path == FA_COOPMAT1) { const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index cbdf86d3b..0bf78123b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1037,11 +1037,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 16d42c4ae..54f4ed248 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5dde51306..515d6c163 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/src/models/openelm.cpp b/src/models/openelm.cpp index ee46a3375..fbf682ec8 100644 --- a/src/models/openelm.cpp +++ b/src/models/openelm.cpp @@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index ce6ccee0d..ee476cdc4 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -675,15 +675,12 @@ int main(int argc, char ** argv) { } } - for (int i = 0; i < (int) embd.size(); i += params.n_batch) { - int n_eval = (int) embd.size() - i; - if (n_eval > params.n_batch) { - n_eval = params.n_batch; - } - + if (!embd.empty()) { + int n_eval = (int) embd.size(); LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + GGML_ASSERT(n_eval <= params.n_batch); + if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -744,7 +741,7 @@ int main(int argc, char ** argv) { common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false); ++n_consumed; - if ((int) embd.size() >= params.n_batch) { + if ((int) embd.size() == params.n_batch) { break; } }