diff --git a/common/arg.cpp b/common/arg.cpp index 2c534a4fc..115752704 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1453,7 +1453,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.swa_full = true; } - )); + ).set_env("LLAMA_ARG_SWA_FULL")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), @@ -2066,13 +2066,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.grp_attn_w = value; } ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(common_arg( - {"-dkvc", "--dump-kv-cache"}, - "verbose print of the KV cache", - [](common_params & params) { - params.dump_kv_cache = true; - } - )); add_opt(common_arg( {"-nkvo", "--no-kv-offload"}, "disable KV offload", diff --git a/common/common.cpp b/common/common.cpp index 4337af8ea..f44bc537b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1337,81 +1337,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// KV cache utils -// - -void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { - static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; - - printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", - view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); - - llama_kv_cache_view_cell * c_curr = view.cells; - llama_seq_id * cs_curr = view.cells_sequences; - - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - if (i % row_size == 0) { - printf("\n%5d: ", i); - } - int seq_count = 0; - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] >= 0) { seq_count++; } - } - putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]); - } - - printf("\n=== Done dumping\n"); -} - -void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { - static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - - printf("=== Dumping KV cache. 
total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", - view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); - - std::unordered_map seqs; - llama_kv_cache_view_cell * c_curr = view.cells; - llama_seq_id * cs_curr = view.cells_sequences; - - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] < 0) { continue; } - if (seqs.find(cs_curr[j]) == seqs.end()) { - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - const size_t sz = seqs.size(); - seqs[cs_curr[j]] = sz; - } - } - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - } - - printf("=== Sequence legend: "); - for (const auto & it : seqs) { - printf("%zu=%d, ", it.second, it.first); - } - printf("'+'=other sequence ids"); - - c_curr = view.cells; - cs_curr = view.cells_sequences; - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - if (i % row_size == 0) { - printf("\n%5d: ", i); - } - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] >= 0) { - const auto & it = seqs.find(cs_curr[j]); - putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+'); - } else { - putchar('.'); - } - } - putchar(' '); - } - - printf("\n=== Done dumping\n"); -} - // // Embedding utils // diff --git a/common/common.h b/common/common.h index 9de574323..beecd1eb4 100644 --- a/common/common.h +++ b/common/common.h @@ -326,7 +326,6 @@ struct common_params { bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation - bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data @@ -618,16 +617,6 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); -// -// KV cache utils -// - -// Dump the KV cache view with the number of sequences per cell. -void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); - -// Dump the KV cache view showing individual sequences in each cell (long output). 
-void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); - // // Embedding utils // diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d027271fc..2c55d2149 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,5 +1,8 @@ #include "cpy.cuh" #include "dequantize.cuh" +#ifdef GGML_USE_MUSA +#include "ggml-musa/mudnn.cuh" +#endif // GGML_USE_MUSA typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -597,7 +600,14 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg #endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); +#ifdef GGML_USE_MUSA + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { + CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); + } else +#endif // GGML_USE_MUSA + { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index be0329d0e..7120053b6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -772,7 +772,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); - GGML_UNUSED(kb0); + GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 921c52da6..cc218dd53 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#ifndef GGML_USE_HIP __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // GGML_USE_HIP static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16( NO_DEVICE_CODE; return; } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
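// Sketch of the compile-out pattern used above (hypothetical kernel name; the real
// kernels keep their full template parameter lists). On builds where a template
// instantiation can never be launched, its body collapses to NO_DEVICE_CODE so the
// compiler spends no registers or code on it, while the symbol still exists for the
// host-side template machinery. The host dispatch further down in this patch forces
// a single column per block on NVIDIA devices, which is what makes this guard safe.
template <int ncols>
static __global__ void vec_kernel_sketch(float * dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    if (ncols > 1) {
        NO_DEVICE_CODE; // unreachable on NVIDIA: the host side only launches ncols == 1
        return;
    }
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    dst[threadIdx.x] = float(ncols); // stand-in for the real single-column work
}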
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ half maskh_shared[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = 0.0f; + } + __syncthreads(); // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: @@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16( for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid]; + } + + __syncthreads(); + + // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. + // In such cases, skip the KV slice. + // On AMD __all_sync would not work correctly because it assumes a warp size of 64. +#ifndef GGML_USE_HIP + bool skip = true; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]); + skip = skip && isinf(tmp.x) && isinf(tmp.y); + } + } + if (__all_sync(0xFFFFFFFF, skip)) { + continue; + } +#endif // GGML_USE_HIP + } + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, // see https://github.com/ggerganov/llama.cpp/pull/7061 . // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). @@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16( sum = logit_softcap*tanhf(sum); } - sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += maskh_shared[j*D + i_KQ]; if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 7064675d5..49c592ea5 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#ifndef GGML_USE_HIP __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // GGML_USE_HIP static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, const char * __restrict__ K, @@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32( NO_DEVICE_CODE; return; } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
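// Sketch of the warp-vote skip added above (assumptions: WARP_SIZE == 32 as on
// NVIDIA hardware, fully masked positions hold -inf, and the helper name is made up).
// Each lane checks the mask values it loaded; __all_sync() then reports whether every
// lane in the warp agrees, so a fully masked KV slice can be skipped with no lane
// divergence.
static __device__ bool kv_slice_fully_masked_sketch(const float * mask_tile, const int n) {
    bool skip = true;
    for (int i = threadIdx.x; i < n; i += WARP_SIZE) {
        skip = skip && isinf(mask_tile[i]); // -inf marks a masked-out position
    }
    return __all_sync(0xFFFFFFFF, skip);    // unanimous vote across the 32 lanes
}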
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ float maskf_shared[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*D + tid] = 0.0f; + } + __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: @@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32( for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]); + } + + __syncthreads(); + + // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. + // In such cases, skip the KV slice. + // On AMD __all_sync would not work correctly because it assumes a warp size of 64. +#ifndef GGML_USE_HIP + bool skip = true; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + skip = skip && isinf(maskf_shared[j*D + i]); + } + } + if (__all_sync(0xFFFFFFFF, skip)) { + continue; + } +#endif // GGML_USE_HIP + } + float kqmax_new_arr[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32( sum = logit_softcap*tanhf(sum); } - sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + sum += maskf_shared[j*D + i_KQ]; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e94b6cd75..f18473dcb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3255,7 +3255,7 @@ template< typename kd4x4_t, // key type in device memory short nl_k, void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), - typename vd4x4_t, // key type in device memory + typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size @@ -3776,7 +3776,7 @@ template< typename kd4_t, // key type in device memory short nl_k, void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // key type in device memory + typename vd4_t, // value type in device memory short nl_v, void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size diff --git a/ggml/src/ggml-musa/mudnn.cu b/ggml/src/ggml-musa/mudnn.cu new file mode 100644 index 000000000..020c1702c --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cu @@ -0,0 +1,112 @@ +#include +#include + +#include "mudnn.cuh" + +namespace mudnn = musa::dnn; + +// Returns a human-readable error string for mudnn::Status +const char* mudnnGetErrorString(mudnn::Status err) { + switch (err) { + case mudnn::Status::SUCCESS: + return "Success"; + case mudnn::Status::INVALID_PARAMETER: + return "Invalid parameter"; + case mudnn::Status::NOT_INITIALIZED: + return "Not initialized"; + 
case mudnn::Status::ALLOC_FAILED: + return "Allocation failed"; + case mudnn::Status::NOT_SUPPORTED: + return "Not supported"; + case mudnn::Status::INTERNAL_ERROR: + return "Internal error"; + case mudnn::Status::ARCH_MISMATCH: + return "Architecture mismatch"; + case mudnn::Status::EXECUTION_FAILED: + return "Execution failed"; + default: + return "Unknown mudnn status"; + } +} + +// Error checking macro for MUDNN calls +#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString) + +namespace { + // Thread-safe cache for mudnn::Handle objects per device + std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache; + std::mutex handle_cache_mutex; + + mudnn::Handle* get_cached_handle(int device_id) { + std::lock_guard<std::mutex> lock(handle_cache_mutex); + auto it = handle_cache.find(device_id); + if (it != handle_cache.end()) { + return it->second.get(); + } + auto handle = std::make_unique<mudnn::Handle>(device_id); + mudnn::Handle* handle_ptr = handle.get(); + handle_cache[device_id] = std::move(handle); + return handle_ptr; + } +} + +// Extracts dimensions and strides from a ggml_tensor +int get_ggml_dims_and_strides(const ggml_tensor* tensor, + std::vector<int64_t>& dims, + std::vector<int64_t>& strides) { + const int ndims = ggml_n_dims(tensor); + const size_t element_size = ggml_element_size(tensor); + + dims.resize(ndims); + strides.resize(ndims); + + for (int i = 0; i < ndims; ++i) { + dims[i] = tensor->ne[i]; + strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size); + } + return ndims; +} + +// Converts ggml_type to mudnn::Tensor::Type +mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return mudnn::Tensor::Type::FLOAT; + case GGML_TYPE_F16: + return mudnn::Tensor::Type::HALF; + + // TODO: Add support for other types + + default: + MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED); + } + + return mudnn::Tensor::Type::FLOAT; // Default fallback +} + +// Asynchronous memory copy using mudnn::Unary::IDENTITY +musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) { + mudnn::Tensor tensor_dst, tensor_src; + + MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type))); + MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type))); + + std::vector<int64_t> dims, strides; + const int ndims = get_ggml_dims_and_strides(src, dims, strides); + + MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_dst.SetAddr(dst->data)); + MUDNN_CHECK(tensor_src.SetAddr(src->data)); + + mudnn::Unary op; + MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY)); + MUDNN_CHECK(op.SetAlpha(0.0f)); + MUDNN_CHECK(op.SetBeta(0.0f)); + + mudnn::Handle* handle = get_cached_handle(ctx.device); + MUDNN_CHECK(handle->SetStream(ctx.stream())); + MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src)); + + return musaSuccess; +} diff --git a/ggml/src/ggml-musa/mudnn.cuh b/ggml/src/ggml-musa/mudnn.cuh new file mode 100644 index 000000000..a63be5755 --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "../include/ggml.h" +#include "../ggml-cuda/common.cuh" + +// Asynchronously copies data from src tensor to dst tensor using the provided context. +// Returns a musaError_t indicating success or failure.
+musaError_t mudnnMemcpyAsync( + ggml_backend_cuda_context &ctx, + const ggml_tensor *dst, + const ggml_tensor *src +); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8db90b965..ea3782af5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4537,6 +4537,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index 39184ef58..b604c1881 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -1,6 +1,6 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #include "dequant_head.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 7859a1a60..26163b167 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -7,7 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif #if defined(DATA_A_IQ1_M) -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif #if defined(DATA_A_BF16) && defined(COOPMAT) diff --git a/include/llama.h b/include/llama.h index 78c1d612e..8437a301c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -610,72 +610,13 @@ extern "C" { // KV cache // - // TODO: start using struct llama_kv_cache - - // Information associated with an individual cell in the KV cache view. - struct llama_kv_cache_view_cell { - // The position for this cell. Takes KV cache shifts into account. - // May be negative if the cell is not populated. - llama_pos pos; - }; - - // An updateable view of the KV cache. - struct llama_kv_cache_view { - // Number of KV cache cells. This will be the same as the context size. - int32_t n_cells; - - // Maximum number of sequences that can exist in a cell. It's not an error - // if there are more sequences in a cell than this value, however they will - // not be visible in the view cells_sequences. - int32_t n_seq_max; - - // Number of tokens in the cache. For example, if there are two populated - // cells, the first with 1 sequence id in it and the second with 2 sequence - // ids then you'll have 3 tokens. - int32_t token_count; - - // Number of populated cache cells. - int32_t used_cells; - - // Maximum contiguous empty slots in the cache. - int32_t max_contiguous; - - // Index to the start of the max_contiguous slot range. Can be negative - // when cache is full. - int32_t max_contiguous_idx; - - // Information for an individual cell. - struct llama_kv_cache_view_cell * cells; - - // The sequences for each cell. There will be n_seq_max items per cell. - llama_seq_id * cells_sequences; - }; - - // Create an empty KV cache view. (use only for debugging purposes) - LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); - - // Free a KV cache view. 
(use only for debugging purposes) - LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); - - // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) - // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx) - LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); - - /// - // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx); - DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx), - "use llama_kv_self_n_tokens instead"); - // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx); - DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx), - "use llama_kv_self_used_cells instead"); - // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_self_clear( struct llama_context * ctx); @@ -758,61 +699,6 @@ extern "C" { // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) LLAMA_API void llama_kv_self_update(struct llama_context * ctx); - DEPRECATED(LLAMA_API void llama_kv_cache_clear( - struct llama_context * ctx), - "use llama_kv_self_clear instead"); - - DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_rm instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_cp instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_kv_self_seq_keep instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "use llama_kv_self_seq_add instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "use llama_kv_self_seq_div instead"); - - DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_kv_self_seq_pos_max instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx), - "use llama_kv_self_defrag instead"); - - DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx), - "use llama_kv_self_can_shift instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx), - "use llama_kv_self_update instead"); - - // // State / sessions // diff --git a/otherarch/embeddings_adapter.cpp b/otherarch/embeddings_adapter.cpp index 6797a3d86..9ea7e37d5 100644 --- a/otherarch/embeddings_adapter.cpp +++ b/otherarch/embeddings_adapter.cpp @@ -149,7 +149,7 @@ bool embeddingstype_load_model(const embeddings_load_model_inputs inputs) } std::vector tmp = {1, 2, 3, 4}; - llama_kv_cache_clear(embeddings_ctx); + llama_kv_self_clear(embeddings_ctx); auto er = llama_decode(embeddings_ctx, llama_batch_get_one(tmp.data(), 
tmp.size())); if(er!=0) { @@ -190,7 +190,7 @@ embeddings_generation_outputs embeddingstype_generate(const embeddings_generatio double timetaken = 0; timer_start(); - llama_kv_cache_clear(embeddings_ctx); + llama_kv_self_clear(embeddings_ctx); std::string prompt = inputs.prompt; // max batch size diff --git a/otherarch/tts_adapter.cpp b/otherarch/tts_adapter.cpp index 57e50b319..2f7a676e3 100644 --- a/otherarch/tts_adapter.cpp +++ b/otherarch/tts_adapter.cpp @@ -559,7 +559,7 @@ bool ttstype_load_model(const tts_load_model_inputs inputs) } std::vector tmp = {1, 2, 3, 4}; - llama_kv_cache_clear(ttc_ctx); + llama_kv_self_clear(ttc_ctx); auto er = llama_decode(ttc_ctx, llama_batch_get_one(tmp.data(), tmp.size())); if(er!=0) { @@ -618,8 +618,8 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) const std::string sampletext = (custom_speaker_text=="")?process_text("but that is what it is",ttsver):process_text(custom_speaker_text,ttsver); // process prompt and generate voice codes - llama_kv_cache_clear(ttc_ctx); - llama_kv_cache_clear(cts_ctx); + llama_kv_self_clear(ttc_ctx); + llama_kv_self_clear(cts_ctx); std::vector prompt_inp; prompt_init(prompt_inp, ttcvocab); @@ -817,7 +817,7 @@ tts_generation_outputs ttstype_generate(const tts_generation_inputs inputs) } } guide_tokens.clear(); - llama_kv_cache_clear(ttc_ctx); + llama_kv_self_clear(ttc_ctx); prompt_init(prompt_inp, ttcvocab); next_token_uses_guide_token = true; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7bf729f47..a1dec63aa 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2288,39 +2288,10 @@ int32_t llama_apply_adapter_cvec( return res ? 0 : -1; } -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); - return {}; - } - - return llama_kv_cache_view_init(*kv, n_seq_max); -} - -void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); - return; - } - - llama_kv_cache_view_update(view, kv); -} - // // kv cache // -// deprecated -int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { - return llama_kv_self_n_tokens(ctx); -} - int32_t llama_kv_self_n_tokens(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2330,11 +2301,6 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) { return kv->get_n_tokens(); } -// deprecated -int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { - return llama_kv_self_used_cells(ctx); -} - int32_t llama_kv_self_used_cells(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2344,11 +2310,6 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) { return kv->get_used_cells(); } -// deprecated -void llama_kv_cache_clear(llama_context * ctx) { - llama_kv_self_clear(ctx); -} - void llama_kv_self_clear(llama_context * ctx) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2358,15 +2319,6 @@ void llama_kv_self_clear(llama_context * ctx) { kv->clear(); } -// deprecated -bool llama_kv_cache_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); -} - bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id 
seq_id, @@ -2380,16 +2332,6 @@ bool llama_kv_self_seq_rm( return kv->seq_rm(seq_id, p0, p1); } -// deprecated -void llama_kv_cache_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, @@ -2404,13 +2346,6 @@ void llama_kv_self_seq_cp( kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -// deprecated -void llama_kv_cache_seq_keep( - llama_context * ctx, - llama_seq_id seq_id) { - llama_kv_self_seq_keep(ctx, seq_id); -} - void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2420,16 +2355,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { kv->seq_keep(seq_id); } -// deprecated -void llama_kv_cache_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); -} - void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, @@ -2444,16 +2369,6 @@ void llama_kv_self_seq_add( kv->seq_add(seq_id, p0, p1, delta); } -// deprecated -void llama_kv_cache_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); -} - void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, @@ -2477,11 +2392,6 @@ llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { return kv->seq_pos_min(seq_id); } -// deprecated -llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_pos_max(ctx, seq_id); -} - llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { const auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2491,11 +2401,6 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { return kv->seq_pos_max(seq_id); } -// deprecated -void llama_kv_cache_defrag(llama_context * ctx) { - llama_kv_self_defrag(ctx); -} - void llama_kv_self_defrag(llama_context * ctx) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2506,11 +2411,6 @@ void llama_kv_self_defrag(llama_context * ctx) { kv->defrag_sched(-1.0f); } -// deprecated -bool llama_kv_cache_can_shift(const llama_context * ctx) { - return llama_kv_self_can_shift(ctx); -} - bool llama_kv_self_can_shift(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2520,11 +2420,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) { return kv->get_can_shift(); } -// deprecated -void llama_kv_cache_update(llama_context * ctx) { - llama_kv_self_update(ctx); -} - // llama state API // deprecated diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 71d416c9f..8c9e7080e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1368,6 +1368,10 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } } if (wo_b) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index aedc4a6e0..ca7390280 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2888,38 +2888,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } - -// -// kv cache view -// - -llama_kv_cache_view 
llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) { - llama_kv_cache_view result = { - /*.n_cells = */ 0, - /*.n_seq_max = */ n_seq_max, - /*.token_count = */ 0, - /*.used_cells = */ kv.get_used_cells(), - /*.max_contiguous = */ 0, - /*.max_contiguous_idx = */ -1, - /*.cells = */ nullptr, - /*.cells_sequences = */ nullptr, - }; - - return result; -} - -void llama_kv_cache_view_free(llama_kv_cache_view * view) { - if (view->cells != nullptr) { - free(view->cells); - view->cells = nullptr; - } - if (view->cells_sequences != nullptr) { - free(view->cells_sequences); - view->cells_sequences = nullptr; - } -} - -void llama_kv_cache_view_update(llama_kv_cache_view * , const llama_kv_cache * ) { - // TODO: will be removed soon, keep this for now to avoid too many changes in - // https://github.com/ggml-org/llama.cpp/pull/13194 -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 256a7d43e..bd0485bc6 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -534,12 +534,3 @@ private: bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; - - -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max); - -void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 19f0c4ed8..295d4ca18 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4903,8 +4903,21 @@ struct llm_build_llama_iswa : public llm_graph_context { ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - { - // llama4 MoE + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 7a3288672..bb20db150 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -231,12 +231,14 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, while (i < n_tokens) { // split into batches text_batch.n_tokens = 0; // clear the batch for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) { + int32_t j = text_batch.n_tokens; + text_batch.token [j] = tokens[i]; + text_batch.pos [j] = n_past++; + text_batch.n_seq_id[j] = 1; + text_batch.seq_id [j][0] = seq_id; + text_batch.logits [j] = false; + text_batch.n_tokens++; - text_batch.token [i] = tokens[i]; - text_batch.pos [i] = n_past++; - text_batch.n_seq_id[i] = 1; - text_batch.seq_id [i][0] = seq_id; - text_batch.logits [i] = false; } bool is_last_token = (i == n_tokens); if (logits_last && is_last_token) {
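// Sketch of the indexing fix in mtmd-helper.cpp above (hypothetical helper and sizes
// for illustration). With, say, n_tokens = 2500 and n_batch = 512, the old code
// indexed the batch arrays with the running token index i and wrote far past the
// n_batch-sized buffers; indexing with the in-batch slot j = text_batch.n_tokens
// keeps every write inside [0, n_batch) and restarts at 0 for each split.
static void fill_text_batches_sketch(llama_batch & text_batch, const llama_token * tokens,
                                     int32_t n_tokens, int32_t n_batch,
                                     llama_pos & n_past, llama_seq_id seq_id) {
    int32_t i = 0;
    while (i < n_tokens) {
        text_batch.n_tokens = 0;                     // start a fresh split
        for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) {
            const int32_t j = text_batch.n_tokens++; // in-batch slot, always < n_batch
            text_batch.token   [j]    = tokens[i];
            text_batch.pos     [j]    = n_past++;
            text_batch.n_seq_id[j]    = 1;
            text_batch.seq_id  [j][0] = seq_id;
            text_batch.logits  [j]    = false;
        }
        // ... llama_decode() on text_batch would run here, as in the helper above ...
    }
}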