From 27ebfcacbaadc6104e2b18acd8f13515cbf63dce Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Fri, 9 May 2025 13:02:07 +0200 Subject: [PATCH 01/33] llama : do not crash if there is no CPU backend (#13395) * llama : do not crash if there is no CPU backend * add checks to examples --- src/llama-adapter.cpp | 6 ++++++ src/llama-model-loader.cpp | 4 ++++ src/llama-model.cpp | 13 +++++++++++++ tools/main/main.cpp | 7 ++++++- tools/mtmd/clip.cpp | 12 ++++++++---- tools/mtmd/llava.cpp | 1 + tools/rpc/rpc-server.cpp | 18 ++++++++++-------- 7 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index 7ac54d239..8d94034ae 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ std::vector buft_extra; { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) @@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } buft = ggml_backend_dev_buffer_type(cpu_dev); break; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index ea73a8a7b..1c8bce385 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mmaps_used.reserve(files.size()); for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + if (!reg) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? 
-1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e8b78c1d0..21b12339a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -299,6 +299,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // add extra buffer types, only if no GPU device is present // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); @@ -1484,6 +1488,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { @@ -1672,6 +1679,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto * buft_dev = ggml_backend_buft_get_device(buft); if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } buft = ggml_backend_dev_buffer_type(cpu_dev); } @@ -4122,6 +4132,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!dev) { // FIXME: workaround for CPU backend buft having a NULL device dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } } ggml_backend_dev_props props; ggml_backend_dev_get_props(dev, &props); diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 756297c25..1bd2be2d9 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -152,7 +152,12 @@ int main(int argc, char ** argv) { LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + LOG_ERR("%s: no CPU backend found\n", __func__); + return 1; + } + auto * reg = ggml_backend_dev_backend_reg(cpu_dev); auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new"); auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free"); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 4e1a73287..1a81c1fcd 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -352,9 +352,12 @@ struct clip_ctx { clip_ctx(clip_context_params & ctx_params) { backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); - backend = ctx_params.use_gpu - ? 
ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) - : nullptr; + if (!backend_cpu) { + throw std::runtime_error("failed to initialize CPU backend"); + } + backend = ctx_params.use_gpu + ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) + : nullptr; if (backend) { LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend)); @@ -2185,9 +2188,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity) { struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { g_logger_state.verbosity_thold = ctx_params.verbosity; - clip_ctx * ctx_clip = new clip_ctx(ctx_params); + clip_ctx * ctx_clip = nullptr; try { + ctx_clip = new clip_ctx(ctx_params); clip_model_loader loader(fname, *ctx_clip); loader.load_hparams(); loader.load_tensors(); diff --git a/tools/mtmd/llava.cpp b/tools/mtmd/llava.cpp index b85ab112b..ebef8b3c1 100644 --- a/tools/mtmd/llava.cpp +++ b/tools/mtmd/llava.cpp @@ -212,6 +212,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector ggml_build_forward_expand(gf, flatten); ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend"); ggml_backend_graph_compute(backend.get(), gf); struct ggml_tensor* result = ggml_graph_node(gf, -1); diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index a3f901a22..581c74018 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -237,15 +237,17 @@ static ggml_backend_t create_backend(const rpc_server_params & params) { backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } - fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend)); + if (backend) { + fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend)); - // set the number of threads - ggml_backend_dev_t dev = ggml_backend_get_device(backend); - ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; - if (reg) { - auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (ggml_backend_set_n_threads_fn) { - ggml_backend_set_n_threads_fn(backend, params.n_threads); + // set the number of threads + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? 
ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, params.n_threads); + } } } From 0cf6725e9f9a164c39f7a87214d60342f7f946d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 9 May 2025 13:34:58 +0200 Subject: [PATCH 02/33] CUDA: FA support for Deepseek (Ampere or newer) (#13306) * CUDA: FA support for Deepseek (Ampere or newer) * do loop unrolling via C++ template --- ggml/src/ggml-cuda/CMakeLists.txt | 2 +- ggml/src/ggml-cuda/common.cuh | 19 + ggml/src/ggml-cuda/cp-async.cuh | 11 + ggml/src/ggml-cuda/fattn-common.cuh | 26 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 960 +++++++++++------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 4 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 4 +- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 2 +- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 2 +- ggml/src/ggml-cuda/fattn.cu | 116 ++- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +- ...ttn-mma-f16-instance-ncols1_1-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 12 +- ...ttn-mma-f16-instance-ncols1_2-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 12 +- ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 12 +- ...ttn-mma-f16-instance-ncols1_4-ncols2_16.cu | 5 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 12 +- ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 12 +- ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 12 +- .../template-instances/generate_cu_files.py | 21 +- src/llama-graph.cpp | 11 + 33 files changed, 852 insertions(+), 547 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 969a178f6..c9ff4aa32 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND) set(CUDA_CXX_FLAGS "") - set(CUDA_FLAGS -use_fast_math) + set(CUDA_FLAGS -use_fast_math -extended-lambda) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") # Options are: diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 919217dfa..64fb4ff4c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -296,6 +296,25 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +// The compiler is always able to unroll loops if they contain continue expressions. 
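To summarize the first patch before the CUDA changes continue: it applies one defensive pattern throughout — resolve the CPU device, fail with a clear error if it is missing, and treat optional registry entry points (such as ggml_backend_set_n_threads) as nullable. A minimal sketch of that pattern, assuming the ggml-backend API already used in the hunks above; the helper names are illustrative, not part of the patch:

    #include <stdexcept>
    #include "ggml-backend.h"

    // Illustrative helper: configure the thread count of a backend only if both
    // its registry entry and the optional proc address actually exist.
    static bool backend_set_n_threads_if_possible(ggml_backend_t backend, int n_threads) {
        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
        if (reg == nullptr) {
            return false; // no registry entry for this backend, nothing to configure
        }
        auto set_n_threads_fn = (ggml_backend_set_n_threads_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_n_threads_fn == nullptr) {
            return false; // the backend does not expose this optional entry point
        }
        set_n_threads_fn(backend, n_threads);
        return true;
    }

    // Illustrative helper: the CPU device itself may be absent (e.g. a build
    // without the CPU backend), which is exactly what the patch now reports
    // cleanly instead of dereferencing a null device.
    static ggml_backend_dev_t require_cpu_device() {
        ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
        if (cpu_dev == nullptr) {
            throw std::runtime_error("no CPU backend found");
        }
        return cpu_dev;
    }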
+// In such cases loop unrolling can still be achieved via recursion: +template +struct ggml_cuda_unroll { + template + __device__ void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_cuda_unroll{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index ecb659997..63d0c482f 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -2,6 +2,17 @@ #include "common.cuh" + +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + // Copies data from global to shared memory, cg == cache global. // Both the src and dst pointers must be aligned to 16 bit. // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c7dc72882..b7180d595 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { @@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE @@ -691,7 +691,7 @@ void launch_fattn( GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); @@ -754,10 +754,13 @@ void launch_fattn( const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. 
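The occupancy query on the next line is what lets launch_fattn size the stream-k grid from the hardware instead of hard-coding two blocks per SM. A standalone sketch of the same query, with an illustrative stand-in kernel and block size (the real call passes fattn_kernel, its block dimensions and its dynamic shared memory requirement; error checking omitted):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void example_kernel(float * out) { // stand-in for fattn_kernel
        out[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
    }

    int main() {
        int device = 0;
        cudaGetDevice(&device);

        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device);
        const int nsm = prop.multiProcessorCount;

        const int    block_size    = 128; // block_dim.x*block_dim.y*block_dim.z in launch_fattn
        const size_t nbytes_shared = 0;   // dynamic shared memory per block

        int max_blocks_per_sm = 1;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, example_kernel, block_size, nbytes_shared);

        // One full wave of blocks across the GPU, the new upper bound for stream-k:
        const int max_blocks = max_blocks_per_sm*nsm;
        std::printf("%d SMs, %d blocks/SM -> at most %d blocks per wave\n", nsm, max_blocks_per_sm, max_blocks);
        return 0;
    }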
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; + const int max_blocks = max_blocks_per_sm*nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); @@ -769,14 +772,11 @@ void launch_fattn( blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); @@ -853,19 +853,19 @@ void launch_fattn( if (stream_k) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } } else if (parallel_blocks > 1) { - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 04804a15c..2b6bdc30c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,104 +13,217 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; -template +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. +// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. 
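Before the per-head-size specializations, it is worth illustrating the ggml_cuda_unroll helper added to common.cuh above, which the tile loaders later in this file drive with a lambda instead of a #pragma unroll loop. A host-compilable reduction of the same recursion; the names here are illustrative, the real helper is __device__-only:

    #include <cstdio>

    // Because every iteration becomes its own statement, the body may return
    // early (the equivalent of a `continue`) without preventing full
    // compile-time expansion.
    template <int n>
    struct unroll_n {
        template <typename Func, typename... Args>
        void operator()(const Func & f, Args... args) const {
            f(n - 1, args...);             // run iteration n-1 first, matching ggml_cuda_unroll
            unroll_n<n - 1>{}(f, args...); // then recurse over the remaining iterations
        }
    };

    template <>
    struct unroll_n<1> {
        template <typename Func, typename... Args>
        void operator()(const Func & f, Args... args) const {
            f(0, args...);
        }
    };

    int main() {
        // Same call shape as the ggml_cuda_unroll{}(load) calls in the tile loaders below.
        unroll_n<4>{}([](const int n, const int base) {
            if (n == 2) {
                return;                    // the "continue" case the comment above refers to
            }
            std::printf("iteration %d, base %d\n", n, base);
        }, 10);
        return 0;
    }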
+ +template +struct fattn_mma_f16_config; + +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 32; + static constexpr int nbatch_V2 = 32; + static constexpr int nbatch_combine = 32; +}; + +template <> +struct fattn_mma_f16_config< 80, 80> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 40; + static constexpr int nbatch_V2 = 40; + static constexpr int nbatch_combine = 40; +}; + +template <> +struct fattn_mma_f16_config< 96, 96> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 48; + static constexpr int nbatch_V2 = 48; + static constexpr int nbatch_combine = 48; +}; + +template <> +struct fattn_mma_f16_config<112, 112> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 56; + static constexpr int nbatch_V2 = 56; + static constexpr int nbatch_combine = 56; +}; + +template <> +struct fattn_mma_f16_config<128, 128> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 64; + static constexpr int nbatch_V2 = 64; + static constexpr int nbatch_combine = 64; +}; + +template <> +struct fattn_mma_f16_config<256, 256> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + static constexpr int nbatch_K2 = 128; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + static constexpr int nbatch_K2 = 160; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +// ------------------------------------------------------------------------------------------------------------------ + +template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - // If cp.async is available, load up to the highest power of 2 in D asynchronously: -#ifdef CP_ASYNC_AVAILABLE - static_assert(D >= 64 && D < 512, "bad D"); - constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 
64 : 128); - - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); - - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; - constexpr int stride_i = WARP_SIZE / chunks_per_row; -#pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); - const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; - - cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); - } -#else - constexpr int k0_sync_start = 0; -#endif // CP_ASYNC_AVAILABLE - static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); - - // If D is not a power of 2, the rest is loaded synchronously. // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - if (k0_start == k0_stop || k0_stop <= k0_sync_start) { - continue; - } + if (use_cp_async) { + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; -#pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); -#pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + if (k0_start == k0_stop) { + return; } - } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + }; + ggml_cuda_unroll<5>{}(load); + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 
0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } + }; + ggml_cuda_unroll<3>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); -#ifdef CP_ASYNC_AVAILABLE - constexpr int preload = KQ_per_iter * sizeof(half); - constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + if (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; + } + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); if (j0 + stride_j > ncols1 && j >= ncols1) { break; } - const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); - cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; } -#else - constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; -#pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); - - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; - } - - const int i = KQ_per_iter == 2*WARP_SIZE ? 
threadIdx.x : threadIdx.x % (WARP_SIZE/2); - - tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; - } -#endif // CP_ASYNC_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -123,9 +236,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, + half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, half2 * const __restrict__ tile_mask, @@ -135,59 +250,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - const int k_VKQ_0 = kb0 * KQ_per_iter; - tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; // Use wide variants of tiles if ntiles >= 2. tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; -#ifdef CP_ASYNC_AVAILABLE - cp_async_wait_all(); - __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); -#else - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); - } - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { - tile_A K_A; - load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } - } + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. 
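A few lines further down, the kernel optionally applies logit softcapping to the freshly computed KQ values. The reference form of that operation, as a scalar sketch; in the device code the division by the cap is already folded into the Q scale by the host-side launcher (not shown in this hunk), so only the tanhf and the multiplication remain:

    #include <cmath>

    // Reference logit softcap: the raw KQ value is squashed into (-cap, cap)
    // while staying monotonic.
    static float logit_softcap_ref(float kq, float cap) {
        return cap * std::tanh(kq / cap);
    }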
-#endif // CP_ASYNC_AVAILABLE +#pragma unroll + for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { + const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } if (use_logit_softcap) { - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -205,7 +368,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { @@ -213,16 +376,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
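The comment above describes the online-softmax bookkeeping used throughout this kernel: a running row maximum, a scale factor that rescales previously accumulated values whenever the maximum grows, and a row sum that is only divided out at the very end. A scalar, CPU-side reference of that scheme (illustrative data, single accumulator entry):

    #include <cfloat>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        // KQ values arrive in chunks (one chunk per nbatch_fa KV rows).
        const std::vector<std::vector<float>> kq_chunks = {{-1.0f, 0.5f}, {2.0f, 1.0f}, {0.0f, 3.0f}};

        float kq_max    = -FLT_MAX/2.0f; // same initialization as KQ_max in the kernel
        float kq_rowsum = 0.0f;
        float vkq       = 0.0f;          // stand-in for a single VKQ accumulator entry
        const float v   = 1.0f;          // stand-in for the corresponding V values

        for (const std::vector<float> & chunk : kq_chunks) {
            float kq_max_new = kq_max;
            for (const float x : chunk) {
                kq_max_new = std::fmax(kq_max_new, x);
            }
            const float kq_max_scale = std::exp(kq_max - kq_max_new); // rescale old accumulators
            kq_max = kq_max_new;

            kq_rowsum *= kq_max_scale;
            vkq       *= kq_max_scale;
            for (const float x : chunk) {
                const float p = std::exp(x - kq_max);
                kq_rowsum += p;
                vkq       += p*v;
            }
        }

        // The divisor stored in KQ_rowsum is applied exactly once, at the end:
        std::printf("softmax-weighted value: %f\n", vkq/kq_rowsum);
        return 0;
    }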
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); @@ -238,10 +401,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); - + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); @@ -252,7 +414,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } else { // ntiles > 1 if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; #pragma unroll for (int t = 0; t < ntiles/2; ++t) { @@ -261,7 +423,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; @@ -272,9 +434,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. 
- static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -294,9 +456,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -325,7 +487,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { #pragma unroll for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; @@ -336,7 +498,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { #pragma unroll for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; @@ -347,16 +509,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); @@ -364,52 +526,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } -#ifdef CP_ASYNC_AVAILABLE - // Preload K tile for next iteration: - cp_async_wait_all(); - __syncthreads(); - if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } -#else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate VKQ tile: #pragma unroll - for (int i_VKQ_0 = 0; 
i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); -#pragma unroll - for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { + const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; + const int i0_diff = i0_stop - i0_start; - tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); - } else { + if (nstages == 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate VKQ tile: #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } } } } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } } - -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE - #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); @@ -419,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -434,7 +611,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int ne02, const int stride_Q1, const int stride_Q2, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, const int kb0_start, @@ -442,6 +620,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
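The first thing the tile-processing function does below is pull in the per-head-size configuration (typedef fattn_mma_f16_config with the DKQ/DV template arguments). A compilable sketch of that pattern, reproducing only the new 576/512 (DeepSeek MLA) specialization and a subset of its fields, with values taken from the specializations earlier in this file:

    #include <cstdio>

    // Compile-time configuration selected by (DKQ, DV); only constexpr members,
    // so tile sizes and loop bounds are fixed at compile time.
    template <int DKQ, int DV> struct fattn_mma_f16_config;

    template <> struct fattn_mma_f16_config<576, 512> {
        static constexpr int  nbatch_fa      = 32;    // KV rows per softmax rescale
        static constexpr int  nwarps_max     = 8;
        static constexpr bool Q_in_reg       = false; // too large to keep Q in registers
        static constexpr int  nstages_target = 1;
        static constexpr int  nbatch_K2      = 160;   // K half2 values loaded per step
        static constexpr int  nbatch_V2      = 128;   // V half2 values loaded per step
    };

    template <int DKQ, int DV>
    static void print_config() {
        typedef fattn_mma_f16_config<DKQ, DV> c;      // same lookup the kernel performs
        std::printf("DKQ=%d DV=%d: nbatch_fa=%d, K batch=%d half2, V batch=%d half2, Q kept in %s\n",
                    DKQ, DV, c::nbatch_fa, c::nbatch_K2, c::nbatch_V2,
                    c::Q_in_reg ? "registers" : "shared memory");
    }

    int main() {
        print_config<576, 512>();
        return 0;
    }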
+ typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; @@ -449,22 +635,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; - // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: - extern __shared__ half2 tile_K[]; -#ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_per_iter*D2_padded; -#else - half2 * tile_V = tile_K; -#endif // CP_ASYNC_AVAILABLE - half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; - tile_B Q_B[D/(2*tile_B::J) * ntiles]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; @@ -476,13 +659,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_max[col] = -FLT_MAX/2.0f; } - // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { @@ -506,14 +690,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; - tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } } } @@ -521,18 +705,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - { + if (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); } } } @@ -540,35 +724,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - // Preload mask and K data for first iteration when using cp_async: -#ifdef CP_ASYNC_AVAILABLE - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); -#endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } - // With cp_async there is no __syncthreads at the end of the iter, + // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. -#ifdef CP_ASYNC_AVAILABLE - if (nwarps*cols_per_warp > KQ_per_iter) { + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } -#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. @@ -584,38 +770,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - // Write VKQ accumulators to shared memory in column-major format. 
- // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. - // Also for np > 1 the combination is done via these values in shared memory. - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. -#pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); - - tile_K[jc_cwd*D2_padded + k] = B.x[l]; - } - } - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); - - tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } - } - } - } + constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); if constexpr (ntiles == 1) { const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset @@ -624,7 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -649,7 +810,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -676,11 +837,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); - float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. 
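What follows combines the per-warp maxima and row sums with butterfly shuffles. A standalone CUDA sketch of that kind of reduction over one warp, merging partial (max, sum) pairs the way flash attention requires, i.e. rescaling each partial sum by exp(max_partial - max_combined); the kernel itself splits this into the KQ_cms/KQ_crs steps, so treat this as an illustration with made-up input data:

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void combine_max_rowsum(const float * part_max, const float * part_sum, float * out) {
        float m = part_max[threadIdx.x];
        float s = part_sum[threadIdx.x];
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            const float m_other = __shfl_xor_sync(0xFFFFFFFF, m, offset, 32);
            const float s_other = __shfl_xor_sync(0xFFFFFFFF, s, offset, 32);
            const float m_new   = fmaxf(m, m_other);
            s = s*expf(m - m_new) + s_other*expf(m_other - m_new); // keep sums consistently scaled
            m = m_new;
        }
        if (threadIdx.x == 0) {
            out[0] = m; // combined maximum
            out[1] = s; // combined row sum
        }
    }

    int main() {
        float h_max[32], h_sum[32], h_out[2];
        for (int i = 0; i < 32; ++i) {
            h_max[i] = 0.25f*i; // pretend per-warp partial maxima
            h_sum[i] = 1.0f;    // pretend per-warp partial row sums
        }
        float *d_max, *d_sum, *d_out;
        cudaMalloc(&d_max, sizeof(h_max));
        cudaMalloc(&d_sum, sizeof(h_sum));
        cudaMalloc(&d_out, sizeof(h_out));
        cudaMemcpy(d_max, h_max, sizeof(h_max), cudaMemcpyHostToDevice);
        cudaMemcpy(d_sum, h_sum, sizeof(h_sum), cudaMemcpyHostToDevice);
        combine_max_rowsum<<<1, 32>>>(d_max, d_sum, d_out);
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
        std::printf("combined max %.2f, combined rowsum %.3f\n", h_out[0], h_out[1]);
        cudaFree(d_max); cudaFree(d_sum); cudaFree(d_out);
        return 0;
    }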
@@ -690,10 +851,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } float KQ_cms[nmeta]; // KQ combine max scale per warp. @@ -709,10 +869,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: @@ -720,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int imeta = 0; imeta < nmeta; ++imeta) { if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } @@ -736,88 +895,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - if (np > 1) { - __syncthreads(); - } - - if (np == 1 || threadIdx.y % np == 0) { - // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. - // The values after that are for the partial results of the individual blocks. - float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - if (k0_start == k0_stop) { - continue; + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } } - + } else { #pragma unroll - for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - - if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { - break; - } - - const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; - - const int j_dst = jc_dst / ncols2; - const int c_dst = jc_dst % ncols2; - - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { - continue; - } - - const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - - float2 dstk_val = make_float2(0.0f, 0.0f); + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { #pragma unroll - for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); - dstk_val.x += dstk_val_add.x*KQ_crs; - dstk_val.y += dstk_val_add.y*KQ_crs; - } + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); - if (!needs_fixup && !is_fixup) { - const float KQ_rowsum_j = meta_j[1]; - dstk_val.x /= KQ_rowsum_j; - dstk_val.y /= KQ_rowsum_j; - } - - if (is_fixup) { - dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; - } else { - dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; } } } } - } - if (np > 1) { __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 
1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } + } + } + } + } + if (np > 1) { + __syncthreads(); + } } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -template -__launch_bounds__(nwarps*WARP_SIZE, 2) +template +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -857,24 +1046,27 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } - static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); + const int stride_K = nb11 / sizeof(half2); + const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; @@ -893,9 +1085,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -905,14 +1097,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -931,9 +1123,9 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -942,9 +1134,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -960,28 +1152,42 @@ static __global__ void flash_attn_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int ncols = ncols1 * ncols2; - constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; - constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; - constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); - constexpr int cols_per_warp = ntiles * tile_B::I; + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; - static_assert(D % tile_B::J == 0, "bad D"); + typedef fattn_mma_f16_config c; + + constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; + constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? 
DV /2 : c::nbatch_V2; + constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + + constexpr int ncols = ncols1 * ncols2; + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const ggml_tensor * KQV = dst; - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); - const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); - const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); - - const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? 
+ std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -989,59 +1195,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ - template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) 
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) -// Kernels with ncols == 128 are only 4% faster due to register pressure. -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. 
+// The number of viable configurations for Deepseek is very limited: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index e0039e175..9283560d5 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; case 128: { @@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index fcb6f848f..32673adb5 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; case 128: { @@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e17d2d0e4..ef0addc1d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d42ddca49..7064675d5 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index bc21b27a0..c5668adb1 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ 
b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 7a2d1e453..9c5c803d0 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,58 +8,32 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -template +template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } if (Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); } -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } -} - -static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const float use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; if (use_gqa_opt && gqa_ratio % 8 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 4) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 2) { - 
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + switch (Q->ne[0]) { + case 64: + GGML_ASSERT(V->ne[0] == 64); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + break; + case 80: + GGML_ASSERT(V->ne[0] == 80); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + break; + case 96: + GGML_ASSERT(V->ne[0] == 96); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + break; + case 112: + GGML_ASSERT(V->ne[0] == 112); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + break; + case 128: + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + break; + case 256: + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } break; + default: + GGML_ABORT("fatal error"); + break; + } } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ @@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 42302e4ec..7643c4b8b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet - return false; + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + return false; + } + const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; + return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } if (op->src[0]->ne[0] == 192) { return false; } - if 
(op->src[0]->ne[0] == 576) { - // DeepSeek MLA - return false; - } if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu new file mode 100644 index 000000000..fb26abeb0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 80108615a..dc1682902 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 1, 8); -DECL_FATTN_MMA_F16_CASE(80, 1, 8); -DECL_FATTN_MMA_F16_CASE(96, 1, 8); -DECL_FATTN_MMA_F16_CASE(112, 1, 8); -DECL_FATTN_MMA_F16_CASE(128, 1, 8); -DECL_FATTN_MMA_F16_CASE(256, 1, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu index 66161c0ab..9d3cfd8ed 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 1); -DECL_FATTN_MMA_F16_CASE(80, 16, 1); -DECL_FATTN_MMA_F16_CASE(96, 16, 1); -DECL_FATTN_MMA_F16_CASE(112, 16, 1); -DECL_FATTN_MMA_F16_CASE(128, 16, 1); -DECL_FATTN_MMA_F16_CASE(256, 16, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu index ee88c72aa..2e1883af4 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 2); -DECL_FATTN_MMA_F16_CASE(80, 16, 2); -DECL_FATTN_MMA_F16_CASE(96, 16, 2); -DECL_FATTN_MMA_F16_CASE(112, 16, 2); -DECL_FATTN_MMA_F16_CASE(128, 16, 2); -DECL_FATTN_MMA_F16_CASE(256, 16, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 
d888a5a42..2074e954a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 4); -DECL_FATTN_MMA_F16_CASE(80, 16, 4); -DECL_FATTN_MMA_F16_CASE(96, 16, 4); -DECL_FATTN_MMA_F16_CASE(112, 16, 4); -DECL_FATTN_MMA_F16_CASE(128, 16, 4); -DECL_FATTN_MMA_F16_CASE(256, 16, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu new file mode 100644 index 000000000..f011a208c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index d93a2d08e..24c64cf00 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 4); -DECL_FATTN_MMA_F16_CASE(80, 2, 4); -DECL_FATTN_MMA_F16_CASE(96, 2, 4); -DECL_FATTN_MMA_F16_CASE(112, 2, 4); -DECL_FATTN_MMA_F16_CASE(128, 2, 4); -DECL_FATTN_MMA_F16_CASE(256, 2, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 617464c94..163b1d939 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 8); -DECL_FATTN_MMA_F16_CASE(80, 2, 8); -DECL_FATTN_MMA_F16_CASE(96, 2, 8); -DECL_FATTN_MMA_F16_CASE(112, 2, 8); -DECL_FATTN_MMA_F16_CASE(128, 2, 8); -DECL_FATTN_MMA_F16_CASE(256, 2, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu index 970d2b686..0543532ea 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 1); -DECL_FATTN_MMA_F16_CASE(80, 
32, 1); -DECL_FATTN_MMA_F16_CASE(96, 32, 1); -DECL_FATTN_MMA_F16_CASE(112, 32, 1); -DECL_FATTN_MMA_F16_CASE(128, 32, 1); -DECL_FATTN_MMA_F16_CASE(256, 32, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu index 65cd377c3..407b6cf4c 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 2); -DECL_FATTN_MMA_F16_CASE(80, 32, 2); -DECL_FATTN_MMA_F16_CASE(96, 32, 2); -DECL_FATTN_MMA_F16_CASE(112, 32, 2); -DECL_FATTN_MMA_F16_CASE(128, 32, 2); -DECL_FATTN_MMA_F16_CASE(256, 32, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu new file mode 100644 index 000000000..f5fd0e236 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu index f4a8bf348..5e4668502 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 2); -DECL_FATTN_MMA_F16_CASE(80, 4, 2); -DECL_FATTN_MMA_F16_CASE(96, 4, 2); -DECL_FATTN_MMA_F16_CASE(112, 4, 2); -DECL_FATTN_MMA_F16_CASE(128, 4, 2); -DECL_FATTN_MMA_F16_CASE(256, 4, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index de191a8ab..1ada657f1 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 4); -DECL_FATTN_MMA_F16_CASE(80, 4, 4); -DECL_FATTN_MMA_F16_CASE(96, 4, 4); -DECL_FATTN_MMA_F16_CASE(112, 4, 4); -DECL_FATTN_MMA_F16_CASE(128, 4, 4); -DECL_FATTN_MMA_F16_CASE(256, 4, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 
4, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index e8cb0e1b3..bad296b41 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 8); -DECL_FATTN_MMA_F16_CASE(80, 4, 8); -DECL_FATTN_MMA_F16_CASE(96, 4, 8); -DECL_FATTN_MMA_F16_CASE(112, 4, 8); -DECL_FATTN_MMA_F16_CASE(128, 4, 8); -DECL_FATTN_MMA_F16_CASE(256, 4, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu index a532e9629..0d7a9c728 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 1); -DECL_FATTN_MMA_F16_CASE(80, 64, 1); -DECL_FATTN_MMA_F16_CASE(96, 64, 1); -DECL_FATTN_MMA_F16_CASE(112, 64, 1); -DECL_FATTN_MMA_F16_CASE(128, 64, 1); -DECL_FATTN_MMA_F16_CASE(256, 64, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu index bf25181aa..9d5a9976f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 1); -DECL_FATTN_MMA_F16_CASE(80, 8, 1); -DECL_FATTN_MMA_F16_CASE(96, 8, 1); -DECL_FATTN_MMA_F16_CASE(112, 8, 1); -DECL_FATTN_MMA_F16_CASE(128, 8, 1); -DECL_FATTN_MMA_F16_CASE(256, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu index 378c132e6..a6e6f093d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 2); -DECL_FATTN_MMA_F16_CASE(80, 8, 2); -DECL_FATTN_MMA_F16_CASE(96, 8, 2); -DECL_FATTN_MMA_F16_CASE(112, 8, 2); -DECL_FATTN_MMA_F16_CASE(128, 8, 2); -DECL_FATTN_MMA_F16_CASE(256, 8, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); 
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 372641be9..86d4ffae2 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 4); -DECL_FATTN_MMA_F16_CASE(80, 8, 4); -DECL_FATTN_MMA_F16_CASE(96, 8, 4); -DECL_FATTN_MMA_F16_CASE(112, 8, 4); -DECL_FATTN_MMA_F16_CASE(128, 8, 4); -DECL_FATTN_MMA_F16_CASE(256, 8, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 9ff5968b6..680a13ca6 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 8); -DECL_FATTN_MMA_F16_CASE(80, 8, 8); -DECL_FATTN_MMA_F16_CASE(96, 8, 8); -DECL_FATTN_MMA_F16_CASE(112, 8, 8); -DECL_FATTN_MMA_F16_CASE(128, 8, 8); -DECL_FATTN_MMA_F16_CASE(256, 8, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index dd373a09d..3428113dc 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,18 +57,21 @@ for vkq_size in [16, 32]: with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for ncols in [8, 16, 32, 64, 128]: - for ncols2 in [1, 2, 4, 8]: +for ncols in [8, 16, 32, 64]: + for ncols2 in [1, 2, 4, 8, 16]: + if ncols2 > ncols: + continue ncols1 = ncols // ncols2 - if ncols == 128: - continue # Too much register pressure. with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size in [64, 80, 96, 112, 128, 256]: - if ncols == 128 and head_size == 256: - continue # Needs too much shared memory. 
- f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size)) + for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + if head_size_kq != 576 and ncols2 == 16: + continue + if head_size_kq == 576 and ncols2 != 16: + continue + head_size_v = head_size_kq if head_size_kq != 576 else 512 + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0b004a8ba..a8bb83cc5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1227,8 +1227,19 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); if (v_mla) { +#if 0 + // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); +#else + // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. + // The permutations are noops and only change how the tensor data is interpreted. + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. +#endif } cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); From 611aa914ef4231fab5d1ad04773c42e119ae2d2e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 9 May 2025 15:14:56 +0300 Subject: [PATCH 03/33] metal : optimize MoE for large batches (#13388) ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 43 ++- ggml/src/ggml-metal/ggml-metal.m | 331 +++++++++++++++-------- ggml/src/ggml-metal/ggml-metal.metal | 373 ++++++++++++++------------ ggml/src/ggml.c | 4 +- 4 files changed, 458 insertions(+), 293 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8721b272d..ea8cf4586 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -299,21 +299,42 @@ typedef struct { } ggml_metal_kargs_mul_mv_ext; typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; + int32_t ne10; + int32_t ne11; // n_expert_used (bcast) + uint64_t nb11; + uint64_t nb12; + int32_t neh11; // n_tokens + uint64_t nbh11; + int32_t ne20; // n_expert_used + uint64_t nb21; +} ggml_metal_kargs_mul_mm_id_map0; + +typedef struct { + int32_t ne20; // n_expert_used + int32_t neh0; + int32_t neh1; + uint64_t nbh1; + uint64_t nbh2; + int32_t ne0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_mul_mm_id_map1; + +typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; + uint64_t nb03; + int32_t neh12; + uint64_t nbh10; + uint64_t nbh11; + uint64_t nbh12; + uint64_t nbh13; + int32_t neh0; + int32_t neh1; + int16_t r2; + int16_t r3; } ggml_metal_kargs_mul_mm_id; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d92392edb..943f0fe72 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -306,28 +306,30 @@ enum ggml_metal_kernel_type { 
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, @@ -650,7 +652,8 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { } if (mem_pool->heaps_to_remove.count > 0) { - for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) { + // remove in reverse order + for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; @@ -659,6 +662,10 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { [mem_pool->heaps removeObjectAtIndex:index]; [ptr release]; + + if (i == 0) { + break; + } } [mem_pool->heaps_to_remove removeAllObjects]; @@ -672,7 +679,7 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { } static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { - const size_t alignment = 32; + const size_t alignment = 256; const size_t size_aligned = GGML_PAD(size, alignment); @@ -1242,28 +1249,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); @@ -2999,7 +3008,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { id pipeline = nil; @@ -3219,8 +3228,6 @@ static bool ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT_ID: { - const int n_as = src0->ne[2]; - // src2 = ids const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); @@ -3234,24 +3241,21 @@ static bool ggml_metal_encode_node( GGML_ASSERT(ne03 == 1); GGML_ASSERT(ne13 == 1); + const uint32_t r2 = 1; + const uint32_t r3 = 1; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - //GGML_ASSERT(dst_rows <= dst_rows_max); + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments - dst_rows > dst_rows_min && - dst_rows <= 
dst_rows_max) { + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -3262,62 +3266,169 @@ static bool ggml_metal_encode_node( default: break; } - id pipeline = nil; + const int64_t neh10 = ne10; // n_embd + const int64_t neh11 = ne21; // n_tokens + const int64_t neh12 = ne02; // n_expert - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); + const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); + const uint64_t nbh11 = nbh10*neh10; + const uint64_t nbh12 = nbh11*neh11; + const uint64_t nbh13 = nbh12*neh12; + + const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; + id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); + if (!h_src1) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); + return false; } - ggml_metal_kargs_mul_mm_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - 
/*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - }; + const int64_t neh0 = ne0; + const int64_t neh1 = ne21; + const int64_t neh2 = ne02; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); + const uint64_t nbh1 = nbh0*neh0; + const uint64_t nbh2 = nbh1*neh1; + //const uint64_t nbh3 = nbh2*neh2; - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; + id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); + if (!h_dst) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); + return false; + } - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + // tokens per expert + const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; + id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); + if (!h_tpe) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); + return false; + } + + // id map + // [n_expert_used, n_tokens] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); + if (!h_ids) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); + return false; + } + + { + const int nth = MIN(1024, ne10/4); + + ggml_metal_kargs_mul_mm_id_map0 args = { + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + neh11, // n_tokens + nbh11, + ne20, // n_expert_used + nb21, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer: h_src1 offset:0 atIndex:3]; + [encoder setBuffer: h_tpe offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:5]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.neh12 =*/ neh12, + /*.nbh10 =*/ nbh10, + /*.nbh11 =*/ nbh11, + /*.nbh12 =*/ nbh12, + /*.nbh13 =*/ nbh13, + /*.neh0 =*/ neh0, + /*.neh1 =*/ neh1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer: h_tpe offset:0 atIndex:3]; + [encoder setBuffer: h_dst offset:0 atIndex:4]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } + + { + GGML_ASSERT(ne0 % 4 == 0); + + const int nth = MIN(1024, ne0/4); + + ggml_metal_kargs_mul_mm_id_map1 args = { + ne20, // n_expert_used + neh0, + neh1, + nbh1, + nbh2, + ne0, + nb1, + nb2, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer: h_dst offset:0 atIndex:1]; + [encoder setBuffer: h_ids offset:0 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } else { id pipeline = nil; @@ -3511,7 +3622,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int64_t ne123 = dst_rows; + const int64_t ne123 = ne20*ne21; if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f4147e93..a808e79c2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ 
b/ggml/src/ggml-metal/ggml-metal.metal @@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm( } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -// TODO: this kernel needs to be reimplemented from scratch for better performance -template -void kernel_mul_mm_id_impl( - int32_t ne00, - int32_t ne02, - uint64_t nb01, - uint64_t nb02, - int32_t ne11, - int32_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int32_t ne0, - int32_t ne1, - int64_t ne0ne1, - device const char * src0, - device const char * src1, - threadgroup ushort2 * rowids, - device char * dst, - threadgroup char * shmem, +template +kernel void kernel_mul_mm_id_map0( + constant ggml_metal_kargs_mul_mm_id_map0 & args, + device const char * src1, + device const char * src2, + device char * hsrc1, + device char * htpe, + device char * hids, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int ide = tgpig[0]; // expert id + + int n_all = 0; + + device int32_t * ids_i32 = (device int32_t *) (hids); + + for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens + device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + + for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used + if (src2_i32[i20] != ide) { + continue; + } + + device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); + device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + + for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { + hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + } + + if (tpitg.x == 0) { + ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + } + + ++n_all; + } + } + + if (tpitg.x == 0) { + device int32_t * tpe_i32 = (device int32_t *) (htpe); + tpe_i32[ide] = n_all; + } +} + +typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; + +template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + +template +kernel void kernel_mul_mm_id_map1( + constant ggml_metal_kargs_mul_mm_id_map1 & args, + device const char * hdst, + device const char * hids, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i20 = tgpig[0]; // used expert + const int i21 = tgpig[1]; // token + + device const int32_t * ids_i32 = (device const int32_t *) (hids); + device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + + const int id = ids_i32[i21*args.ne20 + i20]; + + const int ide = id / args.neh1; + const int idt = id % args.neh1; + + device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + + for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { + dst_f32x4[i0] = hdst_f32x4[i0]; + } +} + +typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; + +template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; + +template +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * tpe, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * 
sa = (threadgroup half *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup half * sb = (threadgroup half *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; + const int im = tgpig.z; - if (r1*BLOCK_SIZE_N >= ne1) return; + device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + + const int neh1 = tpe_i32[im]; + + if (r1*BLOCK_SIZE_N >= neh1) { + return; + } // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; + simdgroup_T8x8 ma[4]; + simdgroup_half8x8 mb[2]; simdgroup_float8x8 mc[8]; - for (int i = 0; i < 8; i++){ + + for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } + short il = (tiitg % THREAD_PER_ROW); - ushort offset1 = il/nl; + const int i12 = im%args.neh12; + const int i13 = im/args.neh12; - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + device const half * y = (device const half *)(src1 + + args.nbh13*i13 + + args.nbh12*i12 + + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) + + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device 
half2x4 *) y); il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); - #pragma unroll(BLOCK_SIZE_K/8) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { #pragma unroll(4) - for (int i = 0; i < 4; i++) { + for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) - for (int i = 0; i < 2; i++) { + for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - #pragma unroll(8) - for (int i = 0; i < 8; i++){ + for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; } } - { + if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; - - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl( } } -template -kernel void kernel_mul_mm_id( - constant ggml_metal_kargs_mul_mm_id & args, - device const char * src0s, - device const char * src1, - device char * dst, - device const char * ids, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - - tgpig.z = 0; - - device const char * src0 = src0s + i02*args.nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); - - // TODO: parallelize this loop - int32_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < 
args.nei1; ii1++) { - for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; - if (id == i02) { - if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - } - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - args.ne00, - args.ne02, - args.nb01, - args.nb02, - args.ne11, - args.ne12, - args.nb10, - args.nb11, - args.nb12, - args.ne0, - _ne1, - (int64_t)args.ne0*args.ne1, - src0, - src1, - rowids, - dst, - shmem, - tgpig, - tiitg, - sgitg); -} - #define QK_NL 16 // @@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mat_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template 
[[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template 
[[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee4fe9f72..bc673292b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec( c = ggml_mul_mat_id(ctx, as, b, ids); as -> [cols, rows, n_expert] - ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] + ids -> [n_expert_used, n_tokens] (i32) c -> [rows, n_expert_used, n_tokens] - in b, n_experts_used can be broadcasted to match the n_expert_used of ids + in b, n_expert_used can be broadcasted to match the n_expert_used of ids c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids */ From 17512a94d636c4b6c1332370acb3e5af3ca70918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= Date: Fri, 9 May 2025 16:34:08 +0100 Subject: [PATCH 04/33] sycl : implementation of reordered Q4_0 MMVQ for Intel GPUs (#12858) * sycl : Implemented reorder Q4_0 mmvq Signed-off-by: Alberto Cabrera * sycl : Fixed mmvq being called when reorder is disabled * sycl : Improved comments in the quants header Signed-off-by: Alberto Cabrera * Use static_assert * safe_div -> ceil_div * Clarify qi comment * change the reorder tensor from init to execute OP * dbg * Undo changes to test-backend-ops * Refactor changes on top of q4_0 reorder fix * Missing Reverts * Refactored opt_for_reorder logic to simplify code path * Explicit inlining and unroll * Renamed mul_mat_algo enum for consistency --------- Signed-off-by: Alberto Cabrera Co-authored-by: romain.biessy --- ggml/src/ggml-sycl/backend.hpp | 17 ++- ggml/src/ggml-sycl/common.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 124 ++++++++++++---- ggml/src/ggml-sycl/mmvq.cpp | 247 ++++++++++++++++++++----------- ggml/src/ggml-sycl/quants.hpp | 61 ++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 71 +++++++-- 6 files changed, 385 insertions(+), 136 deletions(-) create mode 100644 ggml/src/ggml-sycl/quants.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index de814ef91..f78a36ddf 100644 --- 
a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -14,23 +14,24 @@ #define GGML_SYCL_BACKEND_HPP #include "binbcast.hpp" -#include "concat.hpp" #include "common.hpp" +#include "concat.hpp" #include "conv.hpp" #include "convert.hpp" +#include "cpy.hpp" #include "dequantize.hpp" #include "dmmv.hpp" +#include "element_wise.hpp" +#include "gla.hpp" +#include "im2col.hpp" #include "mmq.hpp" #include "mmvq.hpp" -#include "rope.hpp" #include "norm.hpp" +#include "outprod.hpp" +#include "quants.hpp" +#include "rope.hpp" #include "softmax.hpp" #include "tsembd.hpp" -#include "im2col.hpp" #include "wkv.hpp" -#include "outprod.hpp" -#include "element_wise.hpp" -#include "cpy.hpp" -#include "gla.hpp" -#endif // GGML_SYCL_BACKEND_HPP +#endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 69aad938e..60909dde7 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; +extern int g_ggml_sycl_prioritize_dmmv; #define GGML_SYCL_DEBUG(...) \ do { \ diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 68a26fa48..0ea729948 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -49,6 +49,7 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -195,11 +196,13 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); @@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons std::exit(1); } +enum class mul_mat_algo { + DMMV = 0, + MMVQ = 1, + MUL_MAT_SYCL = 2, +}; + inline bool ggml_sycl_supports_mmq(enum ggml_type type) { // TODO: accuracy issues in MMQ GGML_UNUSED(type); return false; } +inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + static bool ggml_sycl_supports_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows, GGML_ASSERT((size % sizeof(block_q4_0) == 0)); 
GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; + auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; stream->parallel_for( @@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { reorder_qw(data_device, ncols, nrows, size, 0, stream); } -/* -* This function could be called when the OP (mul_mat) function support reorder optimizition. -*/ -static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, - ggml_tensor * dst) { - if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT - ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. - dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - src0->type == GGML_TYPE_Q4_0 && - src1->ne[2]==1 && src1->ne[3]==1) { +static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { + return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT + ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. + dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. + dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; +} - ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; - if (!extra) return; //only happen in CI/UT permute case. - - if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder. - - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. 
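// ---------------------------------------------------------------------------
// Illustrative sketch only -- not part of the patch above. reorder_qw() in
// this hunk rearranges a Q4_0 tensor so that all packed 4-bit quants come
// first and all fp16 scales follow, i.e.
//   [d0, qs0][d1, qs1] ... [dN, qsN]  ->  [qs0, qs1, ..., qsN][d0, d1, ..., dN]
// (the same layout documented in the ggml-sycl/quants.hpp header added later
// in this patch, which get_block_offset()/get_d_offset() index into). The
// plain, host-side C++ below mirrors that transformation; the struct
// definition and the helper name reorder_q4_0_host are stand-ins for
// illustration and do not appear in the llama.cpp sources.
#include <cstdint>
#include <cstring>
#include <vector>

namespace reorder_sketch {

constexpr int QK4_0_SKETCH = 32;                    // weights per Q4_0 block

struct block_q4_0_sketch {                          // interleaved layout: scale + quants
    uint16_t d;                                     // fp16 scale (raw bits, for simplicity)
    uint8_t  qs[QK4_0_SKETCH / 2];                  // 32 x 4-bit quants packed into 16 bytes
};

// Produce the reordered buffer: n*16 bytes of quants followed by n fp16 scales.
inline std::vector<uint8_t> reorder_q4_0_host(const std::vector<block_q4_0_sketch> & blocks) {
    const size_t n     = blocks.size();
    const size_t qs_sz = QK4_0_SKETCH / 2;

    std::vector<uint8_t> out(n * qs_sz + n * sizeof(uint16_t));

    uint8_t * qs_dst = out.data();                  // quants region: qs0 .. qsN
    uint8_t * d_dst  = out.data() + n * qs_sz;      // scales region: d0 .. dN

    for (size_t i = 0; i < n; ++i) {
        std::memcpy(qs_dst + i * qs_sz,            blocks[i].qs, qs_sz);
        std::memcpy(d_dst  + i * sizeof(uint16_t), &blocks[i].d, sizeof(uint16_t));
    }
    return out;
}

} // namespace reorder_sketch
// With this split layout, the quants and the scales each occupy one contiguous
// region of the same allocation, which is what the offset helpers in
// quants.hpp assume when the reordered MMVQ/DMMV kernels read the tensor.
// ---------------------------------------------------------------------------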
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, + ggml_tensor * dst, mul_mat_algo mm_algorithm) { + if (!should_reorder_tensor(*ctx, dst)) { + return; } + + ggml_tensor_extra_gpu * extra = static_cast(src0->extra); + if (!extra || extra->optimized_feature.reorder) { + return; // Skip permutations and already reordered tensors + } + + switch (mm_algorithm) { + case mul_mat_algo::DMMV: + if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { + return; + } + break; + case mul_mat_algo::MMVQ: + if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { + return; + } + break; + case mul_mat_algo::MUL_MAT_SYCL: + if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { + return; + } + break; + } + + reorder_qw(src0, ctx->stream()); + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor int64_t min_compute_capability = INT_MAX; if (split) { - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = + (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; for (int id = 0; id < ggml_sycl_info().device_count; ++id) { // skip devices that are not going to do any work: @@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } } } else { - min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; + min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; } // check data types and tensor shapes for custom matrix multiplication kernels: @@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX + // mmvq path is faster in the CUDA backend. - if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda) + if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // TODO: Refactor and cleanup of mul mat dispatching. @@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. 
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); + constexpr bool convert_src1_to_q8_1 = false; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); } else if (use_mul_mat_vec_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + constexpr bool convert_src1_to_q8_1 = true; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); } else if (use_mul_mat_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + constexpr bool convert_src1_to_q8_1 = true; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); } else { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + constexpr bool convert_src1_to_q8_1 = false; + // MUL_MAT_SYCL supports reorder + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); } + GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b92ba2d6..3cade1a42 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,60 @@ #include "mmvq.hpp" + +#include "ggml.h" +#include "common.hpp" +#include "quants.hpp" #include "vecdotq.hpp" -#include + +template +static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + const block_q8_1 * y = (const block_q8_1 *) vy; + + float partial_sum = 0.0f; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; // x block index + // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits + const int bx_offset = block_type::get_block_offset(ibx); + const int d_offset = block_type::get_d_offset(nrows, ncols, ibx); + + // Y block index that aligns with ibx + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + // x block quant index when casting the quants to int + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + + 
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs); + } + } + + auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>()); + + if (sg.leader()) { + dst[row] = sum; + } +} template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } } -static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, +static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { - - stream->submit([&](sycl::handler &cgh) { - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -916,93 +983,95 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } -void ggml_sycl_op_mul_mat_vec_q( - ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_col_size, - const dpct::queue_ptr &stream) { - +void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, + const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size, + const dpct::queue_ptr & stream) { const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); - const int64_t ne00 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold 
the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - for (int i = 0; i < src1_ncols; i++) - { + for (int i = 0; i < src1_ncols; i++) { const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; - const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; - float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, 
stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + default: + GGML_ABORT("fatal error"); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp new file mode 100644 index 000000000..a74e30526 --- /dev/null +++ b/ggml/src/ggml-sycl/quants.hpp @@ -0,0 +1,61 @@ +// +// MIT license +// Copyright (C) 2025 Codeplay Software Ltd. +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_QUANTS_HPP +#define GGML_SYCL_QUANTS_HPP + +#include "ggml-common.h" +#include "ggml.h" + +namespace ggml_sycl_reordered { + + +// The reordered block moves quants (qs) and scales(d) to two +// uniform regions of memory that is contiguous in the same tensor. +// What this means is that instead of having: +// [d0, qs0] [d1, qs1] [d2, qs2] ... 
[dN, qsN] +// We have: +// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] +// +// Notes: out-of-bounds qs will run into d values +// Aligment relies on the allocated size of qs + +template struct block_q_t; + + +// qk number of weights / quants in a block +// qr number of weights in a byte (described as 'before dequantization') +// for quantization types that has low and high bits split, qr is calculated with +// using the lower bits, e.g for Q6 quants QR6 is 2 +// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field) +// See ggml-common.h to see how these are calculated +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK4_0; + static constexpr uint32_t qi = QI4_0; + static constexpr uint32_t qr = QR4_0; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + + static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half); + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + +} // namespace ggml_sycl_reordered + +#endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index c5942008a..cbf664fcf 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -14,8 +14,11 @@ #define GGML_SYCL_VECDOTQ_HPP #include "dpct/helper.hpp" +#include "ggml.h" +#include "quants.hpp" -typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs); static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -252,13 +255,60 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +template struct reorder_vec_dot_q_sycl { + static_assert(T != T, "ggml_type for reorder vecdot not implemented"); +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_0; + + using q4_0_block = ggml_sycl_reordered::block_q_t; + using q4_0_traits = typename q4_0_block::traits; + + __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, + const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; + const ggml_half d = 
*(reinterpret_cast(static_cast(vbq) + d_offset)); + int v[q4_0_traits::vdr_mmvq]; + int u[2 * q4_0_traits::vdr_mmvq]; + +#pragma unroll + + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_uint8(bq4_0, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi); + } + + return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds); + }; +}; + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, - const float &d4, - const sycl::half2 &ds8) { +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, + const sycl::half2 & ds8) { int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { @@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); } - const sycl::float2 ds8f = - ds8.convert(); + const sycl::float2 ds8f = ds8.convert(); // second part effectively subtracts 8 from each quant value return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); @@ -456,13 +505,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq, const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; #pragma unroll for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); } return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); From 33eff4024084d1f0c8441b79f7208a52fad79858 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 9 May 2025 19:29:37 +0200 Subject: [PATCH 05/33] server : vision support via libmtmd (#12898) * server : (experimental) vision support via libmtmd * mtmd : add more api around mtmd_image_tokens * mtmd : add more api around mtmd_image_tokens * mtmd : ability to calc image hash * shared_ptr for mtmd_image_tokens * move hash to user-define ID (fixed) * abstract out the batch management * small fix * refactor logic adding tokens to batch * implement hashing image * use FNV hash, now hash bitmap instead of file data * allow decoding image embedding to be split into batches * rm whitespace * disable some features when mtmd is on * fix --no-mmproj-offload * mtmd_context_params no timings * refactor server_inp to server_tokens * fix the failing test case * init * wip * working version * add mtmd::bitmaps * add test target * rm redundant define * test: mtmd_input_chunks_free * rm outdated comment * fix merging issue * explicitly create mtmd::input_chunks * mtmd_input_chunk_copy * add clone() * improve server_input struct * clip : fix confused naming ffn_up and ffn_down * rm ffn_i/o/g naming * rename n_embd, n_ff * small fix * no check n_ff * fix detokenize * add const to various places * add warning about breaking changes * add c api * helper: use mtmd_image_tokens_get_n_pos * fix ctx_shift * fix name shadowing * more strict condition * support remote image_url * remote image_url log * add CI test * do not log base64 * add "has_multimodal" to /props * remove dangling image * speculative: use slot.cache_tokens.insert * Apply suggestions 
from code review Co-authored-by: Georgi Gerganov * rm can_be_detokenized * on prompt processing done, assert cache_tokens.size * handle_completions_impl returns void * adapt the new web ui * update docs and hot topics * rm assert * small fix (2) --------- Co-authored-by: Georgi Gerganov --- README.md | 3 +- common/arg.cpp | 2 +- docs/multimodal.md | 69 ++++ tools/mtmd/README.md | 33 +- tools/server/CMakeLists.txt | 3 +- tools/server/README.md | 12 + tools/server/server.cpp | 309 +++++++++++++---- tools/server/tests/unit/test_vision_api.py | 59 ++++ tools/server/tests/utils.py | 18 + tools/server/utils.hpp | 367 ++++++++++++++++++++- 10 files changed, 774 insertions(+), 101 deletions(-) create mode 100644 docs/multimodal.md create mode 100644 tools/server/tests/unit/test_vision_api.py diff --git a/README.md b/README.md index e0232478c..0401723ff 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics +- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md) - **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9) -- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141]((https://github.com/ggml-org/llama.cpp/pull/13141))), `libllava` will be deprecated +- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim diff --git a/common/arg.cpp b/common/arg.cpp index 73a3cfe53..f67e0d96d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -40,7 +40,7 @@ using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_LLAVA, - // TODO: add LLAMA_EXAMPLE_SERVER when it's ready + LLAMA_EXAMPLE_SERVER, }; static std::string read_file(const std::string & fname) { diff --git a/docs/multimodal.md b/docs/multimodal.md new file mode 100644 index 000000000..efed473a3 --- /dev/null +++ b/docs/multimodal.md @@ -0,0 +1,69 @@ +# Multimodal + +llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools that support this feature: +- [llama-mtmd-cli](../tools/mtmd/README.md) +- [llama-server](../tools/server/README.md) via the OpenAI-compatible `/chat/completions` API + +To enable it, use one of the 2 methods below: + +- Use the `-hf` option with a [supported model](../../docs/multimodal.md) + - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` + - To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf` +- Use the `-m model.gguf` option with `--mmproj file.gguf` to specify the text model and the multimodal projector respectively + +By default, the multimodal projector will be offloaded to the GPU.
To disable this, add `--no-mmproj-offload` + +For example: + +```sh +# simple usage with CLI +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF + +# simple usage with server +llama-server -hf ggml-org/gemma-3-4b-it-GGUF + +# using local file +llama-server -m gemma-3-4b-it-Q4_K_M.gguf --mmproj mmproj-gemma-3-4b-it-Q4_K_M.gguf + +# no GPU offload +llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload +``` + +## Pre-quantized models + +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. + +Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` + +NOTE: some models may require large context window, for example: `-c 8192` + +```sh +# Gemma 3 +(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-12b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-27b-it-GGUF + +# SmolVLM +(tool_name) -hf ggml-org/SmolVLM-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-256M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-500M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF + +# Pixtral 12B +(tool_name) -hf ggml-org/pixtral-12b-GGUF + +# Qwen 2 VL +(tool_name) -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF + +# Qwen 2.5 VL +(tool_name) -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF + +# Mistral Small 3.1 24B (IQ2_M quantization) +(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF +``` diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md index 20e7696ce..06e1fd097 100644 --- a/tools/mtmd/README.md +++ b/tools/mtmd/README.md @@ -16,38 +16,7 @@ The naming and structure related to multimodal support have evolved, which might ## Pre-quantized models -These are ready-to-use models, most of them come with `Q4_K_M` quantization by default: - -```sh -# Gemma 3 -llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF -llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF -llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF - -# SmolVLM -llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF - -# Pixtral 12B -llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF - -# Qwen 2 VL -llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF - -# Qwen 2.5 VL -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF -llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF - -# Mistral Small 3.1 24B (IQ2_M quantization) -llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF -``` +See the list of pre-quantized model [here](../../docs/multimodal.md) ## How it works and what is `mmproj`? 
diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index aee90388e..17109fddb 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -34,8 +34,9 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) +target_include_directories(${TARGET} PRIVATE ../llava) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) diff --git a/tools/server/README.md b/tools/server/README.md index 0ec786ea7..972ca384e 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -193,6 +193,12 @@ services: LLAMA_ARG_PORT: 8080 ``` +### Multimodal support + +Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature. + +For more details, please refer to the [multimodal documentation](../../docs/multimodal.md) + ## Build `llama-server` is built alongside everything else from the root of the project @@ -749,6 +755,9 @@ This endpoint is public (no API key check). By default, it is read-only. To make "total_slots": 1, "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", "chat_template": "...", + "modalities": { + "vision": false + }, "build_info": "b(build number)-(build commit hash)" } ``` @@ -757,6 +766,7 @@ This endpoint is public (no API key check). By default, it is read-only. To make - `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option) - `model_path` - the path to the model file (same as the `-m` argument) - `chat_template` - the model's original Jinja2 prompt template +- `modalities` - the list of supported modalities ### POST `/props`: Change server global properties. @@ -1069,6 +1079,8 @@ print(completion.choices[0].text) Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming modes are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with the OpenAI API spec are being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. +If the model supports multimodal input, you can pass media files via the `image_url` content part. Both base64-encoded data and remote URLs are supported as input. See the OAI documentation for more details. + *Options:* See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
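For illustration, the `modalities` field that this change adds to `GET /props` (documented above) gives clients a cheap way to check whether the running server was started with a vision projector before attempting to send images. A minimal sketch, assuming a local `llama-server` instance on the default port 8080 and the Python `requests` package:

```python
# Sketch only: probe a running llama-server for vision support using the
# "modalities" field added to GET /props by this change.
import requests

props = requests.get("http://localhost:8080/props").json()
supports_vision = props.get("modalities", {}).get("vision", False)
print("vision:", supports_vision)
```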
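The `image_url` content part accepts either a remote URL or a base64 `data:` URI, as the new `test_vision_api.py` further below exercises. A minimal client-side sketch of both variants, assuming a vision-capable server on `localhost:8080`, the `requests` package, and a hypothetical local file `image.png`; the remote URL is the test image used by the new test:

```python
import base64
import requests

SERVER = "http://localhost:8080"  # assumed default llama-server address

def ask(image_url: str) -> str:
    # OpenAI-compatible chat completion with one text part and one image part
    res = requests.post(f"{SERVER}/chat/completions", json={
        "temperature": 0.0,
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "What is this:\n"},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]},
        ],
    })
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]

# remote URL (the server downloads the image itself)
print(ask("https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"))

# base64 data URI built from a local file
with open("image.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")
print(ask(data_uri))
```

Malformed URLs and non-image payloads are rejected with a non-200 status, which is what the negative cases in the new test below assert.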
diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 06788bbdc..de8ded71f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -7,6 +7,7 @@ #include "log.h" #include "sampling.h" #include "speculative.h" +#include "mtmd.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -197,8 +198,8 @@ struct server_task { int id_target = -1; // used by SERVER_TASK_TYPE_INFERENCE - slot_params params; - llama_tokens prompt_tokens; + slot_params params; + server_tokens prompt_tokens; int id_selected_slot = -1; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE @@ -1248,6 +1249,9 @@ struct server_slot { llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + common_speculative * spec = nullptr; std::vector lora; @@ -1275,14 +1279,14 @@ struct server_slot { int32_t n_prompt_tokens_processed = 0; // input prompt tokens - llama_tokens prompt_tokens; + server_tokens prompt_tokens; size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - llama_tokens cache_tokens; + server_tokens cache_tokens; std::vector generated_token_probs; @@ -1476,7 +1480,7 @@ struct server_slot { {"is_processing", is_processing()}, {"non_causal", is_non_causal()}, {"params", params.to_json()}, - {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"prompt", prompt_tokens.detokenize(ctx, true)}, {"next_token", { {"has_next_token", has_next_token}, @@ -1849,13 +1853,16 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; + // multimodal + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; llama_model * model_dft = nullptr; llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1878,6 +1885,8 @@ struct server_context { common_chat_templates_ptr chat_templates; ~server_context() { + mtmd_free(mctx); + // Clear any sampling context for (server_slot & slot : slots) { common_sampler_free(slot.smpl); @@ -1965,6 +1974,36 @@ struct server_context { chat_templates = common_chat_templates_init(model, "chatml"); } + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.verbosity = params_base.verbosity > 0 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (!params_base.speculative.model.path.empty()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + return true; } @@ -1980,6 +2019,8 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.mctx = mctx; + slot.cache_tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2016,8 +2057,6 @@ struct server_context { // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } @@ -2054,7 +2093,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); @@ -2096,18 +2135,6 @@ struct server_context { return ret; } - bool can_be_detokenized(const struct llama_context * ctx, const std::vector & tokens) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & token : tokens) { - if (token < 0 || token >= n_vocab) { - return false; - } - } - return true; - } - bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); slot.id_task = task.id; @@ -2122,8 +2149,7 @@ struct server_context { slot.lora = slot.params.lora; } - bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens); - if (!can_detokenize) { + if (!slot.prompt_tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -2385,6 +2411,15 @@ struct server_context { queue_results.send(std::move(res)); } + // if multimodal is enabled, send an error and return false + bool ensure_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { auto res = std::make_unique(); @@ -2424,7 +2459,7 @@ struct server_context { res->content = std::move(slot.generated_text); res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->prompt = slot.prompt_tokens.detokenize(ctx, true); res->response_fields = 
std::move(slot.params.response_fields); res->truncated = slot.truncated; @@ -2734,6 +2769,10 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { + if (!ensure_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2753,7 +2792,8 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); + const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; @@ -2770,6 +2810,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { + if (!ensure_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2788,15 +2829,18 @@ struct server_context { std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - slot->cache_tokens.resize(slot->n_ctx); + llama_tokens tokens; + tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.resize(0); + slot->cache_tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } - slot->cache_tokens.resize(token_count); + tokens.resize(token_count); + slot->cache_tokens.clear(); + slot->cache_tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -2813,6 +2857,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { + if (!ensure_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -2844,6 +2889,7 @@ struct server_context { res->id = task.id; queue_results.send(std::move(res)); } break; + } } @@ -2889,6 +2935,12 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = slot.n_past - n_keep; @@ -2900,11 +2952,14 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; } - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + new_tokens.resize(slot.cache_tokens.size() - n_discard); + slot.cache_tokens.clear(); + 
slot.cache_tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -2982,7 +3037,7 @@ struct server_context { SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); // print prompt tokens (for debugging) - if (1) { + /*if (1) { // first 16 tokens (avoid flooding logs) for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); @@ -2992,7 +3047,7 @@ struct server_context { for (int i = 0; i < (int) prompt_tokens.size(); i++) { SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); } - } + }*/ // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { @@ -3034,21 +3089,27 @@ struct server_context { // if input prompt is too big, truncate it if (slot.n_prompt_tokens >= slot.n_ctx) { + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; + const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); llama_tokens new_tokens( - prompt_tokens.begin(), - prompt_tokens.begin() + slot.params.n_keep); + curr_tokens.begin(), + curr_tokens.begin() + slot.params.n_keep); new_tokens.insert( new_tokens.end(), - prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - prompt_tokens.end()); + curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, + curr_tokens.end()); - prompt_tokens = std::move(new_tokens); + prompt_tokens.clear(); + prompt_tokens.insert(new_tokens); slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); @@ -3060,13 +3121,18 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { size_t head_c = slot.n_past; // cache size_t head_p = slot.n_past; // current prompt + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); while (head_c < slot.cache_tokens.size() && @@ -3092,7 +3158,7 @@ struct server_context { llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); slot.n_past++; } @@ -3140,21 +3206,52 @@ struct server_context { // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); + // check if we should process the image + if (slot.n_past < slot.n_prompt_tokens + && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + // process the image + int32_t new_n_past; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + int32_t n_pos = new_n_past - slot.n_past; + + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + slot.release(); 
+ send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + continue; + } + + if (slot.params.cache_prompt) { + const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); + slot.cache_tokens.push_back(chunk.get()); // copy + } + + slot.n_past += n_pos; + slot.n_prompt_tokens_processed += n_pos; + } + // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); - + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { - slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); + slot.cache_tokens.push_back(cur_tok); } slot.n_prompt_tokens_processed++; slot.n_past++; } + // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed @@ -3162,12 +3259,16 @@ struct server_context { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system for (int i = 0; i < slot.n_prompt_tokens; ++i) { - common_sampler_accept(slot.smpl, prompt_tokens[i], false); + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } } // extract the logits only for the last token @@ -3320,6 +3421,11 @@ struct server_context { continue; } + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; @@ -3346,7 +3452,8 @@ struct server_context { params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; params_spec.p_min = slot.params.speculative.p_min; - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // keep track of total number of tokens generated in the draft slot.n_draft_total += draft.size(); @@ -3380,7 +3487,7 @@ struct server_context { slot.n_draft_accepted += ids.size() - 1; slot.cache_tokens.push_back(id); - slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); @@ -3903,6 +4010,7 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, + { "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: 
add more in the future { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, @@ -3950,9 +4058,10 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, + const std::vector & files, const std::function & is_connection_closed, httplib::Response & res, - oaicompat_type oaicompat) { + oaicompat_type oaicompat) -> void { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); if (ctx_server.params_base.embedding) { @@ -3969,15 +4078,69 @@ int main(int argc, char ** argv) { // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); - tasks.reserve(tokenized_prompts.size()); - for (size_t i = 0; i < tokenized_prompts.size(); i++) { + // process files + mtmd::bitmaps bitmaps; + const bool has_mtmd = ctx_server.mctx != nullptr; + { + if (!has_mtmd && !files.empty()) { + throw std::runtime_error("This server does not support multimodal"); + } + for (auto & file : files) { + mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); + if (!bmp.ptr) { + throw std::runtime_error("Failed to load image"); + } + // calculate bitmap hash (for KV caching) + std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + bmp.set_id(hash.c_str()); + bitmaps.entries.push_back(std::move(bmp)); + } + } + + // process prompt + std::vector inputs; + if (oaicompat && !prompt.is_string()) { + throw std::runtime_error("prompt must be a string"); + } + + if (oaicompat && has_mtmd) { + // multimodal + std::string prompt_str = prompt.get(); + mtmd_input_text inp_txt = { + prompt_str.c_str(), + /* add_special */ true, + /* parse_special */ true, + }; + mtmd::input_chunks chunks(mtmd_input_chunks_init()); + auto bitmaps_c_ptr = bitmaps.c_ptr(); + int32_t tokenized = mtmd_tokenize(ctx_server.mctx, + chunks.ptr.get(), + &inp_txt, + bitmaps_c_ptr.data(), + bitmaps_c_ptr.size()); + if (tokenized != 0) { + throw std::runtime_error("Failed to tokenize prompt"); + } + + server_tokens tmp(chunks, true); + inputs.push_back(std::move(tmp)); + } else { + // non-multimodal version + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (auto & p : tokenized_prompts) { + auto tmp = server_tokens(p, ctx_server.mctx != nullptr); + inputs.push_back(std::move(tmp)); + } + } + + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, @@ -4059,9 +4222,11 @@ int main(int argc, char ** argv) { const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, 
req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); @@ -4069,9 +4234,11 @@ int main(int argc, char ** argv) { const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = oaicompat_completion_params_parse(json::parse(req.body)); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_COMPLETION); @@ -4146,9 +4313,11 @@ int main(int argc, char ** argv) { tokenized_prompts[0] ); - return handle_completions_impl( + std::vector files; // dummy + handle_completions_impl( SERVER_TASK_TYPE_INFILL, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_NONE); // infill is not OAI compatible @@ -4162,11 +4331,19 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; + json data = oaicompat_completion_params_parse( + body, + params.use_jinja, + params.reasoning_format, + ctx_server.chat_templates.get(), + ctx_server.mctx, + files); - return handle_completions_impl( + handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, data, + files, req.is_connection_closed, res, OAICOMPAT_TYPE_CHAT); @@ -4175,7 +4352,14 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); + std::vector files; // dummy, unused + json data = oaicompat_completion_params_parse( + body, + params.use_jinja, + params.reasoning_format, + ctx_server.chat_templates.get(), + ctx_server.mctx, + files); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4280,7 +4464,7 @@ int main(int argc, char ** argv) { } } - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { @@ -4300,7 +4484,7 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr); // OAI-compat task.params.oaicompat = oaicompat; @@ -4394,13 +4578,14 @@ int main(int argc, char ** argv) { std::unordered_set task_ids; { std::vector tasks; - std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); tasks.reserve(tokenized_docs.size()); for (size_t i = 0; i < tokenized_docs.size(); i++) { + auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); + task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr); 
tasks.push_back(std::move(task)); } diff --git a/tools/server/tests/unit/test_vision_api.py b/tools/server/tests/unit/test_vision_api.py new file mode 100644 index 000000000..7cc4096f1 --- /dev/null +++ b/tools/server/tests/unit/test_vision_api.py @@ -0,0 +1,59 @@ +import pytest +from utils import * +import base64 +import requests + +server: ServerProcess + +IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" +IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" + +response = requests.get(IMG_URL_0) +response.raise_for_status() # Raise an exception for bad status codes +IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") + + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinygemma3() + + +@pytest.mark.parametrize( + "prompt, image_url, success, re_content", + [ + # test model is trained on CIFAR-10, but it's quite dumb due to small size + ("What is this:\n", IMG_URL_0, True, "(cat)+"), + ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log + ("What is this:\n", IMG_URL_1, True, "(frog)+"), + ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache + ("What is this:\n", "malformed", False, None), + ("What is this:\n", "https://google.com/404", False, None), # non-existent image + ("What is this:\n", "https://ggml.ai", False, None), # non-image data + ] +) +def test_vision_chat_completion(prompt, image_url, success, re_content): + global server + server.start(timeout_seconds=60) # vision model may take longer to load due to download size + if image_url == "IMG_BASE64_0": + image_url = IMG_BASE64_0 + res = server.make_request("POST", "/chat/completions", data={ + "temperature": 0.0, + "top_k": 1, + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": { + "url": image_url, + }}, + ]}, + ], + }) + if success: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index 4dc2062a8..27a0f0356 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -88,6 +88,7 @@ class ServerProcess: chat_template: str | None = None chat_template_file: str | None = None server_path: str | None = None + mmproj_url: str | None = None # session variables process: subprocess.Popen | None = None @@ -194,6 +195,8 @@ class ServerProcess: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: server_args.extend(["--chat-template-file", self.chat_template_file]) + if self.mmproj_url: + server_args.extend(["--mmproj-url", self.mmproj_url]) args = [str(arg) for arg in [server_path, *server_args]] print(f"tests: starting server with: {' '.join(args)}") @@ -379,6 +382,21 @@ class ServerPreset: server.server_reranking = True return server + @staticmethod + def tinygemma3() -> ServerProcess: + server = ServerProcess() + # mmproj is already provided by HF registry API + server.model_hf_repo = "ggml-org/tinygemma3-GGUF" + server.model_hf_file = "tinygemma3-Q8_0.gguf" + server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf" + server.model_alias = "tinygemma3" + server.n_ctx = 1024 + server.n_batch = 32 + 
server.n_slots = 2 + server.n_predict = 4 + server.seed = 42 + return server + def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: """ diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b497959fd..23163f4fe 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -3,7 +3,9 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "arg.h" // common_remote_get_content #include "base64.hpp" +#include "mtmd.h" // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 @@ -21,6 +23,7 @@ #include #include #include +#include #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" @@ -41,6 +44,8 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +using raw_buffer = std::vector; + template static T json_value(const json & body, const std::string & key, const T & default_value) { // Fallback null to default value @@ -386,7 +391,7 @@ static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); } -static inline std::vector base64_decode(const std::string & encoded_string) { +static inline raw_buffer base64_decode(const std::string & encoded_string) { int i = 0; int j = 0; int in_ = 0; @@ -396,7 +401,7 @@ static inline std::vector base64_decode(const std::string & encoded_str uint8_t char_array_4[4]; uint8_t char_array_3[3]; - std::vector ret; + raw_buffer ret; while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; @@ -579,7 +584,9 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const struct common_chat_templates * tmpls) + const struct common_chat_templates * tmpls, + bool allow_non_text, + std::vector & out_files) { json llama_params; @@ -627,8 +634,77 @@ static json oaicompat_completion_params_parse( } } + // get input files + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + json messages = body.at("messages"); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array"); + } + for (auto & msg : messages) { + json & content = msg.at("content"); + if (content.is_string()) { + continue; + } + + if (!content.is_array()) { + throw std::runtime_error("Expected 'content' to be a string or an array"); + } + + for (auto & p : content) { + std::string type = json_value(p, "type", std::string()); + json image_url = json_value(p, "image_url", json::object()); + if (type == "image_url") { + if (!allow_non_text) { + throw std::runtime_error("image input is not supported by this server"); + } + + std::string url = json_value(image_url, "url", std::string()); + if (string_starts_with(url, "http")) { + // download remote image + // TODO @ngxson : maybe make these params configurable + common_remote_params params; + params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.max_size = 1024 * 1024 * 10; // 10MB + params.timeout = 10; // seconds + SRV_INF("downloading image from '%s'\n", url.c_str()); + auto res = common_remote_get_content(url, params); + if (200 <= res.first && res.first < 300) { + SRV_INF("downloaded %ld bytes\n", res.second.size()); + raw_buffer data; + data.insert(data.end(), 
res.second.begin(), res.second.end()); + out_files.push_back(data); + } else { + throw std::runtime_error("Failed to download image"); + } + + } else { + // try to decode base64 image + std::vector parts = string_split(url, /*separator*/ ','); + if (parts.size() != 2) { + throw std::runtime_error("Invalid image_url.url value"); + } else if (!string_starts_with(parts[0], "data:image/")) { + throw std::runtime_error("Invalid image_url.url format: " + parts[0]); + } else if (!string_ends_with(parts[0], "base64")) { + throw std::runtime_error("image_url.url must be base64 encoded"); + } else { + auto base64_data = parts[1]; + auto decoded_data = base64_decode(base64_data); + out_files.push_back(decoded_data); + } + } + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = MTMD_DEFAULT_IMAGE_MARKER; + p.erase("image_url"); + } + } + } + common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); @@ -935,3 +1011,286 @@ static std::vector parse_lora_request( return lora; } + +// +// utils for interacting with libmtmd +// (may need to refactor in near future) +// + +/** + * server_tokens is a helper to manage the input tokens and image for the server. + * it is made this way to simplify the logic of KV cache management. + */ +struct server_tokens { + bool has_mtmd = false; + +private: // disallow accessing these members directly, risking out-of-sync + + // map a **start** position in tokens to the image chunk + std::unordered_map map_pos_to_image; + + // list of tokens + // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token + // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** + // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + llama_tokens tokens; + + // for ex. 
with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // pos 0 1 2 3 4 5 6 7 8 9 + // map_pos_to_image will contain: {5, img0}, {8, img1} + +public: + server_tokens() = default; + ~server_tokens() = default; + + // Prevent copying + server_tokens(const server_tokens&) = delete; + server_tokens& operator=(const server_tokens&) = delete; + + // Allow moving (usually implicitly generated if members are movable) + server_tokens(server_tokens&&) = default; + server_tokens& operator=(server_tokens&&) = default; + + // Allow accessing elements using [] operator + llama_token operator[](size_t index) { return tokens[index]; } + const llama_token& operator[](size_t index) const { return tokens[index]; } + + server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { + for (size_t i = 0; i < mtmd_chunks.size(); ++i) { + push_back(mtmd_chunks[i]); + } + } + + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + + // for debugging + std::string str() const { + std::ostringstream oss; + oss << "tokens: "; + for (const auto & t : tokens) { + if (t == LLAMA_TOKEN_NULL) { + oss << " "; + } else { + oss << t << " "; + } + } + oss << "\n"; + oss << "image pos: "; + for (const auto & it : map_pos_to_image) { + oss << it.first << ", "; + } + return oss.str(); + } + + const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { + auto it = map_pos_to_image.find(pos); + if (it != map_pos_to_image.end()) { + return it->second; + } else { + throw std::runtime_error("Chunk not found"); + } + } + + void push_back(llama_token tok) { + if (tok == LLAMA_TOKEN_NULL) { + throw std::runtime_error("Invalid token"); + } + tokens.emplace_back(tok); + } + + // will create a copy of the chunk if it contains non-text data + void push_back(const mtmd_input_chunk * chunk) { + auto type = mtmd_input_chunk_get_type(chunk); + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + GGML_ASSERT(has_mtmd); + auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk); + const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + llama_pos start_pos = tokens.size(); + for (int i = 0; i < n_pos; ++i) { + tokens.emplace_back(LLAMA_TOKEN_NULL); + } + mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); + map_pos_to_image[start_pos] = std::move(new_chunk); + } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + for (size_t i = 0; i < n_tokens; ++i) { + push_back(text_tokens[i]); + } + } else { + GGML_ABORT("Invalid chunk type"); + } + } + + // for compatibility with context shift and prompt truncation + void insert(const llama_tokens & inp_tokens) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); + } + + // for compatibility with speculative decoding, ctx shift, slot save/load + const llama_tokens & get_text_tokens() const { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + return tokens; + } + + // for compatibility with speculative decoding + void set_token(llama_pos pos, llama_token id) { + GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + tokens[pos] = id; + } + + size_t size() const { + return tokens.size(); + } + + bool empty() const { + return tokens.empty(); + } + + void clear() { + tokens.clear(); + } + + void resize(size_t n) { + GGML_ASSERT(n <= tokens.size()); + if (has_mtmd) { + // we throw an error if we 
try to remove a token in the middle of an image + // for ex. with input of 5 text tokens and 2 images: + // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] + // n 1 2 3 4 5 6 7 8 9 10 + // allowed to resize ^ ^ + // disallowed to resize ^ ^ ^ + if (n > 0) { + llama_token last_token = tokens[n - 1]; + // make sure we never remove tokens in the middle of an image + if (last_token == LLAMA_TOKEN_NULL) { + find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + } + } + // remove all image chunks that are not used anymore + for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) { + llama_pos pos = it->first; + if (pos >= (llama_pos)n) { + it = map_pos_to_image.erase(it); + } else { + ++it; + } + } + } + tokens.resize(n); + } + + std::string detokenize(const llama_context * ctx, bool special) const { + llama_tokens text_tokens; + text_tokens.reserve(tokens.size()); + for (const auto & t : tokens) { + if (t != LLAMA_TOKEN_NULL) { + text_tokens.push_back(t); + } + } + return common_detokenize(ctx, text_tokens, special); + } + + size_t get_common_prefix(const server_tokens & b) const { + size_t max_idx = std::min(tokens.size(), b.tokens.size()); + for (size_t i = 0; i < max_idx; ++i) { + auto & ai = tokens[i]; + auto & bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + GGML_ASSERT(has_mtmd); + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); + const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get()); + const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get()); + std::string ai_id = mtmd_image_tokens_get_id(a_img); + std::string bi_id = mtmd_image_tokens_get_id(b_img); + size_t a_pos = mtmd_image_tokens_get_n_pos(a_img); + size_t b_pos = mtmd_image_tokens_get_n_pos(b_img); + if (ai_id == bi_id && a_pos == b_pos) { + GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen + i += a_pos - 1; // will be +1 by the for loop + continue; + } else { + return i; + } + } else if (ai == bi) { + continue; + } else { + return i; + } + } + return max_idx; // all tokens are equal + } + + // make sure all text tokens are within the vocab range + bool validate(const struct llama_context * ctx) const { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); + + for (size_t i = 0; i < tokens.size(); ++i) { + auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + try { + const auto & chunk = find_chunk(i); + const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get()); + size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + i += n_pos - 1; // will be +1 by the for loop + } catch (const std::exception & e) { + return false; + } + } else if (t < 0 || t >= n_vocab) { + return false; + } + } + return true; + } + + // encode and decode the image chunk + int32_t process_chunk( + llama_context * ctx, + mtmd_context * mctx, + llama_pos n_past, + int32_t seq_id, + llama_pos & n_pos_out) { + auto it = map_pos_to_image.find(n_past); + if (it == map_pos_to_image.end()) { + throw std::runtime_error("Chunk not found"); + } + SRV_INF("%s\n", "processing image..."); + int32_t n_batch = llama_n_batch(ctx); + int64_t t0 = ggml_time_ms(); + llama_pos new_n_past = n_past; + int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, + it->second.get(), // chunk + n_past, + seq_id, + n_batch, + true, // logits last + &new_n_past); + 
SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0); + if (result != 0) { + LOG_ERR("mtmd_helper_eval failed with status %d", result); + n_pos_out = n_past; + return result; + } + n_pos_out = new_n_past; + return 0; + } +}; + +// Computes FNV-1a hash of the data +static std::string fnv_hash(const uint8_t * data, size_t len) { + const uint64_t fnv_prime = 0x100000001b3ULL; + uint64_t hash = 0xcbf29ce484222325ULL; + + for (size_t i = 0; i < len; ++i) { + hash ^= data[i]; + hash *= fnv_prime; + } + return std::to_string(hash); +} From 7c28a74e0783f4bb74a246fb9f19bf212139e365 Mon Sep 17 00:00:00 2001 From: Helton Reis <47722840+HRKings@users.noreply.github.com> Date: Fri, 9 May 2025 17:15:39 -0300 Subject: [PATCH 06/33] chore(llguidance): use tagged version that does not break the build (#13413) --- common/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index ccadb5d1d..6b0011e4d 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -119,8 +119,8 @@ if (LLAMA_LLGUIDANCE) ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance - # v0.7.10: - GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8 + # v0.7.19 (+ fancy-regex build fix): + GIT_TAG b59f98f85269892a7de3d3641ad155366f13daa6 PREFIX ${CMAKE_BINARY_DIR}/llguidance SOURCE_DIR ${LLGUIDANCE_SRC} BUILD_IN_SOURCE TRUE From dc1d2adfc0f4de84da7923866d00781ec5c4e666 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 9 May 2025 23:07:07 -0700 Subject: [PATCH 07/33] vulkan: scalar flash attention implementation (#13324) * vulkan: scalar flash attention implementation * vulkan: always use fp32 for scalar flash attention * vulkan: use vector loads in scalar flash attention shader * vulkan: remove PV matrix, helps with register usage * vulkan: reduce register usage in scalar FA, but perf may be slightly worse * vulkan: load each Q value once. optimize O reduction. 
more tuning * vulkan: support q4_0/q8_0 KV in scalar FA * CI: increase timeout to accommodate newly-supported tests * vulkan: for scalar FA, select between 1 and 8 rows * vulkan: avoid using Float16 capability in scalar FA --- .github/workflows/build.yml | 2 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 243 +++++---- .../vulkan-shaders/flash_attn.comp | 483 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 12 +- 4 files changed, 646 insertions(+), 94 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae066697d..b62720f30 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -307,7 +307,7 @@ jobs: run: | cd build # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 2700 + ctest -L main --verbose --timeout 3600 ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2dc2883a7..e2b357fdc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -275,6 +275,7 @@ struct vk_device_struct { bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_add; + bool subgroup_shuffle; bool integer_dot_product; @@ -402,12 +403,20 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; @@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_num_small_rows(bool scalar) { + return scalar ? 
scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +} + +static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + if (scalar) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 64}; + return {get_fa_num_small_rows(scalar), 32}; } + // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { return {64, 32}; @@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; + +#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, 
flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + +#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, true, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - - auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - uint32_t wg_size = (small_rows && (D % 32) == 0) ? 
256 : 128; - auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; - }; - -#define CREATE_FA2(TYPE, NAMELC, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - -#define CREATE_FA(TYPE, NAMELC) \ - CREATE_FA2(TYPE, NAMELC, 64) \ - CREATE_FA2(TYPE, NAMELC, 80) \ - CREATE_FA2(TYPE, NAMELC, 96) \ - CREATE_FA2(TYPE, NAMELC, 112) \ - CREATE_FA2(TYPE, NAMELC, 128) \ - CREATE_FA2(TYPE, NAMELC, 256) - - 
CREATE_FA(GGML_TYPE_F16, f16) - CREATE_FA(GGML_TYPE_Q4_0, q4_0) - CREATE_FA(GGML_TYPE_Q4_1, q4_1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0) - CREATE_FA(GGML_TYPE_Q5_1, q5_1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0) - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //CREATE_FA(GGML_TYPE_Q2_K, q2_k) - //CREATE_FA(GGML_TYPE_Q3_K, q3_k) - //CREATE_FA(GGML_TYPE_Q4_K, q4_k) - //CREATE_FA(GGML_TYPE_Q5_K, q5_k) - //CREATE_FA(GGML_TYPE_Q6_K, q6_k) - //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) - //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) - //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) - //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) - //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) - //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) - //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) - //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) + CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + } +#endif +#undef CREATE_FA2 #undef CREATE_FA +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ @@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); + bool scalar = !ctx->device->coopmat2; + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar FA, we can use the "large" size to accommodate qga. + // For coopmat FA, we always use the small size (which is still pretty large for gqa). + const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. 
+ gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + vk_pipeline *pipelines; // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= flash_attention_num_small_rows; - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - assert(!"unsupported D value"); - return; + bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= get_fa_num_small_rows(scalar); + + if (scalar) { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + } else { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } } assert(pipelines); @@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); - uint32_t gqa_ratio = 1; - uint32_t qk_ratio = neq2 / nek2; - uint32_t workgroups_x = (uint32_t)neq1; - uint32_t workgroups_y = (uint32_t)neq2; - uint32_t workgroups_z = (uint32_t)neq3; - - if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { - // grouped query attention - make the N dimension equal to gqa_ratio, reduce - // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 - // and change addressing calculations to index Q's dimension 2. 
- gqa_ratio = qk_ratio; - N = gqa_ratio; - workgroups_y /= N; - } - uint32_t split_kv = KV; uint32_t split_k = 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. split_k = ctx->device->shader_core_count * 2 / workgroups_y; if (split_k > 1) { @@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { - return false; - } + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; switch (op->src[0]->ne[0]) { case 64: case 80: @@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case 112: case 128: case 256: - case 575: // DeepSeek MLA break; default: return false; @@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[1]->type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } break; default: return false; } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } return true; } case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 000000000..e6545160d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,483 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t 
nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared vec4 tmpshv4[gl_WorkGroupSize.x]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { 
+#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
+ if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 382f190f2..d196137eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -421,7 +421,6 @@ void process_shaders() { #endif } -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? 
"float16_t" : "float"; @@ -432,6 +431,7 @@ void process_shaders() { } if (tname == "bf16") continue; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); @@ -440,9 +440,17 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } } } -#endif for (const auto& tname : type_names) { // mul mat vec From 7fef11766cdeb9fa7bbbe3db13580616b7d3d599 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 10 May 2025 08:16:29 +0200 Subject: [PATCH 08/33] arg : add env var to control mmproj (#13416) * arg : add env var to control mmproj * small note about -hf --mmproj --- common/arg.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index f67e0d96d..e0f1d998f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2204,32 +2204,33 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING")); add_opt(common_arg( {"--mmproj"}, "FILE", - "path to a multimodal projector file. see tools/mtmd/README.md", + "path to a multimodal projector file. see tools/mtmd/README.md\n" + "note: if -hf is used, this argument can be omitted", [](common_params & params, const std::string & value) { params.mmproj.path = value; } - ).set_examples(mmproj_examples)); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ")); add_opt(common_arg( {"--mmproj-url"}, "URL", "URL to a multimodal projector file. see tools/mtmd/README.md", [](common_params & params, const std::string & value) { params.mmproj.url = value; } - ).set_examples(mmproj_examples)); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_URL")); add_opt(common_arg( {"--no-mmproj"}, "explicitly disable multimodal projector, useful when using -hf", [](common_params & params) { params.no_mmproj = true; } - ).set_examples(mmproj_examples)); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ")); add_opt(common_arg( {"--no-mmproj-offload"}, "do not offload multimodal projector to GPU", [](common_params & params) { params.mmproj_use_gpu = false; } - ).set_examples(mmproj_examples)); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD")); add_opt(common_arg( {"--image"}, "FILE", "path to an image file. use with multimodal models. 
Specify multiple times for batching", From d8919424f1dee7dc1638349c616f2ef5d2ee16fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 10 May 2025 09:16:52 +0200 Subject: [PATCH 09/33] CUDA: fix FlashAttention on Turing (#13415) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 2b6bdc30c..b2f95fa3f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -546,7 +546,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; const int i0_diff = i0_stop - i0_start; - if (nstages == 1) { + if (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); From 053367d149f778cdabc356ee3024494e0dd53223 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 10 May 2025 16:26:42 +0200 Subject: [PATCH 10/33] mtmd : support InternVL 2.5 and 3 (#13422) * convert : internvl support * InternVL3-1B working * fix regression * rm mobilevlm from test * fix conversion * add test for internvl * add to list of pre-quant * restore boi/eoi check * add clarify comment for norm eps --- convert_hf_to_gguf.py | 76 +++++++++++++++++- docs/multimodal.md | 8 ++ gguf-py/gguf/constants.py | 7 ++ gguf-py/gguf/tensor_mapping.py | 12 +++ tools/mtmd/README.md | 1 + tools/mtmd/clip-impl.h | 11 +-- tools/mtmd/clip.cpp | 140 ++++++++++++++++++++++++++++----- tools/mtmd/mtmd.cpp | 7 ++ tools/mtmd/tests.sh | 6 +- 9 files changed, 243 insertions(+), 25 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf6bc6838..e5c397fee 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -426,7 +426,11 @@ class ModelBase: logger.warning(f"Failed to load model config from {dir_model}: {e}") logger.warning("Trying to load config.json instead") with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + config = json.load(f) + if "llm_config" in config: + # rename for InternVL + config["text_config"] = config["llm_config"] + return config @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -2606,6 +2610,11 @@ class Qwen2Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if self.hf_arch == "Qwen2Model": name = f"model.{name}" # map to Qwen2ForCausalLM tensors + if "language_model." 
in name: + name = name.replace("language_model.", "") # for InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] yield from super().modify_tensors(data_torch, name, bid) @@ -2709,6 +2718,62 @@ class Qwen2VLVisionModel(VisionModel): return [] # skip other tensors +@ModelBase.register("InternVisionModel") +class InternVisionModel(VisionModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + # downsample_ratio + downsample_ratio = self.global_config.get("downsample_ratio") + assert downsample_ratio is not None + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("vision_model") or name.startswith("mlp"): + # process visual tensors + # correct name + if name.startswith("vision_model"): + name = "vision_tower." + name + if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + name += ".weight" + # split QKV tensors if needed + if ".qkv." 
in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), + ] + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + @ModelBase.register("WavTokenizerDec") class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -3360,6 +3425,11 @@ class InternLM2Model(TextModel): head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch @@ -3433,6 +3503,10 @@ class InternLM3Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): diff --git a/docs/multimodal.md b/docs/multimodal.md index efed473a3..090583f97 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -66,4 +66,12 @@ NOTE: some models may require large context window, for example: `-c 8192` # Mistral Small 3.1 24B (IQ2_M quantization) (tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF + +# InternVL 2.5 and 3 +(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF +(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF +(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF ``` diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7dd7bb6d1..ae5ce71ae 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -491,6 +491,8 @@ class MODEL_TENSOR(IntEnum): V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() + V_LAYER_SCALE_1 = auto() + V_LAYER_SCALE_2 = auto() V_PRE_NORM = auto() V_POST_NORM = auto() V_MM_INP_NORM = auto() @@ -748,6 +750,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", + MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", MODEL_TENSOR.V_POST_NORM: "v.post_ln", MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", @@ -786,6 +790,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_LAYER_SCALE_1, + MODEL_TENSOR.V_LAYER_SCALE_2, MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_POST_NORM, MODEL_TENSOR.V_MM_INP_PROJ, @@ -2167,6 +2173,7 @@ class VisionProjectorType: PIXTRAL = "pixtral" QWEN2VL = "qwen2vl_merger" 
QWEN25VL = "qwen2.5vl_merger" + INTERNVL = "internvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 003b0172c..bf7ec3257 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -905,6 +905,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", + "mlp1.{bid}", # InternVL ), MODEL_TENSOR.V_MMPROJ_PEG: ( @@ -955,6 +956,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_INPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral @@ -963,6 +965,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_OUTPUT: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral @@ -971,6 +974,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral @@ -1000,6 +1004,14 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), + MODEL_TENSOR.V_LAYER_SCALE_1: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + ), + + MODEL_TENSOR.V_LAYER_SCALE_2: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + ), + MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md index 06e1fd097..ab258ea17 100644 --- a/tools/mtmd/README.md +++ b/tools/mtmd/README.md @@ -48,6 +48,7 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint - Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen)) - [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) +- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported) For older models, please refer to the relevant guide for instructions on how to obtain or create them: diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index fb780e9de..e9c8646e1 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -33,9 +33,6 @@ #define KEY_PROJ_TYPE "clip.projector_type" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" -#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl -#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl - #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" @@ -60,8 +57,10 @@ #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define 
TN_FFN_UP "%s.blk.%d.ffn_up.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" -#define TN_LN_2 "%s.blk.%d.ln2.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm +#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm +#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale +#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale #define TN_LN_PRE "%s.pre_ln.%s" #define TN_LN_POST "%s.post_ln.%s" #define TN_LLAVA_PROJ "mm.%d.%s" @@ -105,6 +104,7 @@ enum projector_type { PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, + PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -119,6 +119,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, + { PROJECTOR_TYPE_INTERNVL, "internvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 1a81c1fcd..dfe7ac91c 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -215,6 +215,10 @@ struct clip_layer { // layernorm 2 ggml_tensor * ln_2_w = nullptr; ggml_tensor * ln_2_b = nullptr; + + // layer scale (no bias) + ggml_tensor * ls_1_w = nullptr; + ggml_tensor * ls_2_w = nullptr; }; struct clip_vision_model { @@ -589,6 +593,9 @@ struct clip_graph { // Qwen2VL and Qwen2.5VL use M-RoPE ggml_cgraph * build_qwen2vl() { + GGML_ASSERT(model.patch_bias == nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; const bool use_window_attn = hparams.n_wa_pattern > 0; const int n_wa_pattern = hparams.n_wa_pattern; @@ -625,10 +632,6 @@ struct clip_graph { n_embd, n_patches_x * n_patches_y, batch_size); } - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - } - ggml_tensor * inpL = inp; ggml_tensor * window_mask = nullptr; ggml_tensor * window_idx = nullptr; @@ -859,6 +862,67 @@ struct clip_graph { return gf; } + ggml_cgraph * build_internvl() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; + ggml_tensor * inp = build_inp(); + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + model.position_embeddings, + nullptr); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + { + const int scale_factor = model.hparams.proj_scale_factor; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + height / scale_factor, + width / scale_factor, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + cur->ne[1] * cur->ne[2]); + } + + // projector (always using GELU activation) + { + // projector LayerNorm uses pytorch's default eps = 1e-5 + // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, 
NORM_TYPE_NORMAL, 1e-5, -1); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); + cur = ggml_add(ctx0, cur, model.mm_3_b); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -890,10 +954,6 @@ struct clip_graph { ggml_tensor * inp = build_inp(); - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - } - // concat class_embeddings and patch_embeddings if (model.class_embedding) { inp = ggml_concat(ctx0, inp, model.class_embedding, 1); @@ -1260,11 +1320,6 @@ private: ggml_tensor * learned_pos_embd, std::function add_pos ) { - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - cb(inp, "patch_bias", -1); - } - if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); @@ -1324,6 +1379,11 @@ private: cb(cur, "attn_out", il); } + if (layer.ls_1_w) { + cur = ggml_mul(ctx0, cur, layer.ls_1_w); + cb(cur, "attn_out_scaled", il); + } + // re-add the layer input, e.g., residual cur = ggml_add(ctx0, cur, inpL); @@ -1344,6 +1404,11 @@ private: cb(cur, "ffn_out", il); + if (layer.ls_2_w) { + cur = ggml_mul(ctx0, cur, layer.ls_2_w); + cb(cur, "ffn_out_scaled", il); + } + // residual 2 cur = ggml_add(ctx0, inpL, cur); cb(cur, "layer_out", il); @@ -1365,6 +1430,10 @@ private: ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } return inp; } @@ -1627,6 +1696,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_minicpmv(); } break; + case PROJECTOR_TYPE_INTERNVL: + { + res = graph.build_internvl(); + } break; default: { res = graph.build_llava(); @@ -1790,6 +1863,7 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; @@ -1897,6 +1971,9 @@ struct clip_model_loader { layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); + layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias + layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias + layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); @@ -1904,7 +1981,7 @@ struct clip_model_loader { layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); - // new naming + // ffn layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false); @@ -2052,6 +2129,15 @@ struct clip_model_loader { 
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_INTERNVL: + { + vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -2838,7 +2924,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 + || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution + ) { clip_image_u8 resized_image; int sz = params.image_size; image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz}); @@ -2988,9 +3076,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP + || ctx->proj_type == PROJECTOR_TYPE_LDPV2 + || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; - n_patches += 2; // for BOI and EOI token embeddings + if (ctx->vision_model.mm_glm_tok_boi) { + n_patches += 2; // for BOI and EOI token embeddings + } } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { if (ctx->minicpmv_version == 2) { n_patches = 96; @@ -3013,7 +3105,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int n_per_side = params.image_size / params.patch_size; int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; n_patches = n_per_side_2d_pool * n_per_side_2d_pool; - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) { + // both W and H are divided by proj_scale_factor n_patches /= (params.proj_scale_factor * params.proj_scale_factor); } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { int n_merge = params.spatial_merge_size; @@ -3408,6 +3501,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: { // do nothing } break; @@ -3434,6 +3528,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // the last node is the embedding tensor ggml_tensor * embeddings = ggml_graph_node(gf, -1); + // sanity check (only support batch size of 1 for now) + const int n_tokens_out = embeddings->ne[1]; + const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get()); + if (n_tokens_out != expected_n_tokens_out) { + LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out); + GGML_ABORT("Invalid number of output tokens"); + } + // copy the embeddings to the location passed by the user 
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); @@ -3604,6 +3706,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: return ctx->vision_model.projection->ne[1]; + case PROJECTOR_TYPE_INTERNVL: + return ctx->vision_model.mm_3_w->ne[1]; default: GGML_ABORT("Unknown projector type"); } diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 2fecf08a4..f1b957394 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -252,6 +252,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } + else if (proj_type == PROJECTOR_TYPE_INTERNVL) { + // ... (image embeddings) ... + marker_modified = "" + ctx->image_marker + ""; + string_replace_all(prompt_modified, ctx->image_marker, marker_modified); + + } + // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix // for glm-edge, BOI and EOI token's embeddings are not present in the text model diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 22c237497..05ac7a04d 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -40,7 +40,6 @@ add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0" add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0" add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" -add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek" add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M" add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna" add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna" @@ -50,6 +49,8 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0" add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" +add_test "llama-mtmd-cli" "ggml-org/InternVL2_5-1B-GGUF:Q8_0" +add_test "llama-mtmd-cli" "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" # to test the big models, run: ./tests.sh big if [ "$RUN_BIG_TESTS" = true ]; then @@ -59,6 +60,8 @@ if [ "$RUN_BIG_TESTS" = true ]; then add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M" + add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M" + add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big fi @@ -70,6 +73,7 @@ fi # this model has broken chat template, not usable # add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" +# add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek" ############### From b064a51a4e7246fa43a3307f58e59f65e6be170a Mon Sep 17 00:00:00 2001 From: Thammachart Chinvarapon <1731496+Thammachart@users.noreply.github.com> Date: Sat, 10 May 2025 21:34:48 +0700 Subject: [PATCH 11/33] ci: free_disk_space flag enabled for intel variant (#13426) before cleanup: 20G after cleanup: 44G after all built and pushed: 24G https://github.com/Thammachart/llama.cpp/actions/runs/14945093573/job/41987371245 --- .github/workflows/docker.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f8e072fb8..2067927be 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -42,8 +42,7 @@ jobs: - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - # Note: the intel images are failing due to an out of disk space error - # - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true } From 43dfd741a5da77982e0a94d1e522a2481dfb914d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 10 May 2025 17:19:52 +0200 Subject: [PATCH 12/33] llguidance : set tokenizer slices to default (#13424) --- common/llguidance.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 8bff89ea4..adce620e4 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -189,6 +189,7 @@ static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, /* .use_approximate_greedy_tokenize_fn = */ false, /* .tokenize_user_data = */ vocab, + /* .slices = */ nullptr, }; char error_buffer[1024]; From 3b24d26c22b9d92b5aa39930988700c84e524cf4 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 10 May 2025 18:44:49 +0200 Subject: [PATCH 13/33] server : update docs (#13432) --- docs/multimodal.md | 2 +- tools/server/README.md | 54 ++++++++++++++++++++++++++++++------------ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/docs/multimodal.md b/docs/multimodal.md index 090583f97..82b9411e4 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -6,7 +6,7 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools To enable it, can use use one of the 2 methods below: -- Use `-hf` option with a [supported model](../../docs/multimodal.md) +- Use `-hf` option with a supported model (see a list of pre-quantized model below) - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` - To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf` - Use `-m model.gguf` option with `--mmproj file.gguf` to specify text and multimodal projector respectively diff --git a/tools/server/README.md b/tools/server/README.md index 972ca384e..7b944c35b 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,13 +7,15 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. 
**Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes - * Reranking endpoint (WIP: https://github.com/ggml-org/llama.cpp/pull/9510) + * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching - * Multimodal (wip) + * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support * Monitoring endpoints * Schema-constrained JSON response format * [Function calling](../../docs/function-calling.md) / tool use for ~any model + * Speculative decoding + * Easy-to-use web UI The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216). @@ -27,6 +29,7 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `--completion-bash` | print source-able bash completion script for llama.cpp | | `--verbose-prompt` | print a verbose prompt before generation (default: false) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | @@ -41,7 +44,7 @@ The project is under active development, and we are [looking for feedback and co | `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)
| | `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) | | `-c, --ctx-size N` | size of the prompt context (default: 4096, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | -| `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)
(env: LLAMA_ARG_N_PREDICT) | +| `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity)
(env: LLAMA_ARG_N_PREDICT) | | `-b, --batch-size N` | logical maximum batch size (default: 2048)
(env: LLAMA_ARG_BATCH) | | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | @@ -69,6 +72,7 @@ The project is under active development, and we are [looking for feedback and co | `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | | `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | | `--list-devices` | print list of available devices and exit | +| `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type | | `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | @@ -82,15 +86,18 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | | `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | -| `-hfr, --hf-repo REPO` | Hugging Face model repository (default: unused)
(env: LLAMA_ARG_HF_REPO) | -| `-hff, --hf-file FILE` | Hugging Face model file (default: unused)
(env: LLAMA_ARG_HF_FILE) | +| `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, defaults to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | +| `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_HFD_REPO) | +| `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | +| `-hfv, -hfrv, --hf-repo-v <user>/<model>[:quant]` | Hugging Face model repository for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_REPO_V) | +| `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_FILE_V) | | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | | `--log-disable` | Log disable | | `--log-file FNAME` | Log to file | | `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | -| `--log-prefix` | Enable prefx in log messages
(env: LLAMA_LOG_PREFIX) | +| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | | `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | @@ -98,9 +105,9 @@ The project is under active development, and we are [looking for feedback and co | Argument | Explanation | | -------- | ----------- | -| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: dry;top_k;typ_p;top_p;min_p;xtc;temperature) | +| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) | | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | -| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) | +| `--sampling-seq, --sampler-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | | `--temp N` | temperature (default: 0.8) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled) | @@ -127,22 +134,26 @@ The project is under active development, and we are [looking for feedback and co | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | `--grammar-file FNAME` | file to read grammar from | | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | -| `--jinja` | Enable experimental Jinja templating engine (required for tool use) | -| `--reasoning-format FORMAT` | Controls extraction of model thinking traces and the format / field in which they are returned (default: `deepseek`; allowed values: `deepseek`, `none`; requires `--jinja`). `none` will leave thinking traces inline in `message.content` in a model-specific format, while `deepseek` will return them separately under `message.reasoning_content` | +| `-jf, --json-schema-file FILE` | File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | + **Example-specific params** | Argument | Explanation | | -------- | ----------- | -| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | +| `--no-context-shift` | disables context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `-sp, --special` | special tokens output enabled (default: false) | | `--no-warmup` | skip warming up the model with an empty run | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | | `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | +| `--mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md
note: if -hf is used, this argument can be omitted
(env: LLAMA_ARG_MMPROJ) | +| `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | +| `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf
(env: LLAMA_ARG_NO_MMPROJ) | +| `--no-mmproj-offload` | do not offload multimodal projector to GPU
(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) | | `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | -| `--host HOST` | ip address to listen (default: 127.0.0.1)
(env: LLAMA_ARG_HOST) | +| `--host HOST` | ip address to listen, or bind to a UNIX socket if the address ends with .sock (default: 127.0.0.1)
(env: LLAMA_ARG_HOST) | | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--no-webui` | Disable the Web UI (default: enabled)
(env: LLAMA_ARG_NO_WEBUI) | @@ -160,16 +171,29 @@ The project is under active development, and we are [looking for feedback and co | `--props` | enable changing global properties via POST /props (default: disabled)
(env: LLAMA_ARG_ENDPOINT_PROPS) | | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
list of built-in templates:
chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--jinja` | use jinja template for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | +| `--reasoning-format FORMAT` | reasoning format (default: deepseek; allowed values: deepseek, none)
controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).
only supported for non-streamed responses
(env: LLAMA_ARG_THINK) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | -| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 5)
(env: LLAMA_ARG_DRAFT_MIN) | -| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.9)
(env: LLAMA_ARG_DRAFT_P_MIN) | +| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_DRAFT_MIN) | +| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)
(env: LLAMA_ARG_DRAFT_P_MIN) | | `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE_DRAFT) | | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_MODEL_DRAFT) | +| `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | +| `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | +| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | +| `--embd-e5-small-en-default` | use default e5-small-v2 model (note: can download weights from the internet) | +| `--embd-gte-small-default` | use default gte-small model (note: can download weights from the internet) | +| `--fim-qwen-1.5b-default` | use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet) | +| `--fim-qwen-3b-default` | use default Qwen 2.5 Coder 3B (note: can download weights from the internet) | +| `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) | +| `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) | +| `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. From 15e6125a397f6086c1dfdf7584acdb7c730313dc Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 10 May 2025 19:57:54 +0200 Subject: [PATCH 14/33] mtmd : add hard limit on image resolution for qwen2vl / qwen2.5vl (#13434) * mtmd : add hard limit on image resolution for qwen2vl / qwen2.5vl * fix typo --- tools/mtmd/clip-impl.h | 3 +++ tools/mtmd/clip.cpp | 35 +++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index e9c8646e1..d7b788bf9 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -92,6 +92,9 @@ #define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" #define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +// align x to upper multiple of n +#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) + enum projector_type { PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP_NORM, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index dfe7ac91c..0ebe81b07 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -174,6 +174,10 @@ struct clip_hparams { int32_t n_layer; int32_t proj_scale_factor = 0; // idefics3 + // for models using dynamic image size, we need to have a smaller image size to warmup + // otherwise, user will get OOM everytime they load the model + int32_t warmup_image_size = 0; + ffn_op_type ffn_op = FFN_GELU; patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; @@ -1796,6 +1800,9 @@ struct clip_model_loader { get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + // default warmup value + hparams.warmup_image_size = hparams.image_size; + ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM || ctx_clip.proj_type == PROJECTOR_TYPE_LDP @@ -1870,6 +1877,7 @@ struct clip_model_loader { case PROJECTOR_TYPE_PIXTRAL: { hparams.rope_theta = 10000.0f; + hparams.warmup_image_size = hparams.patch_size * 8; get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false); } break; case PROJECTOR_TYPE_GEMMA3: @@ -1880,8 +1888,19 @@ struct clip_model_loader { // test model (tinygemma3) has a different value, we optionally read it 
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; + case PROJECTOR_TYPE_QWEN2VL: + { + // max image size = sqrt(max_pixels) + // https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json + hparams.image_size = 3584; + hparams.warmup_image_size = hparams.patch_size * 8; + } break; case PROJECTOR_TYPE_QWEN25VL: { + // max image size = sqrt(max_pixels) + // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json + hparams.image_size = 3584; + hparams.warmup_image_size = hparams.patch_size * 8; get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); } break; default: @@ -2185,13 +2204,14 @@ struct clip_model_loader { // create a fake batch clip_image_f32_batch batch; clip_image_f32_ptr img(clip_image_f32_init()); - img->nx = ctx_clip.vision_model.hparams.image_size; - img->ny = ctx_clip.vision_model.hparams.image_size; + img->nx = ctx_clip.vision_model.hparams.warmup_image_size; + img->ny = ctx_clip.vision_model.hparams.warmup_image_size; img->buf.resize(img->nx * img->ny * 3); batch.entries.push_back(std::move(img)); ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch); ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); + for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { ggml_backend_t backend = ctx_clip.backend_ptrs[i]; ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i]; @@ -2590,8 +2610,8 @@ struct image_manipulation { float target_width_f = static_cast(inp_size.width) * scale; float target_height_f = static_cast(inp_size.height) * scale; - int aligned_width = GGML_PAD((int)target_width_f, align_size); - int aligned_height = GGML_PAD((int)target_height_f, align_size); + int aligned_width = CLIP_ALIGN((int)target_width_f, align_size); + int aligned_height = CLIP_ALIGN((int)target_height_f, align_size); return {aligned_width, aligned_height}; } @@ -2910,10 +2930,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { clip_image_u8 resized; - auto patch_size = clip_get_patch_size(ctx) * 2; - int nx = ceil((float)img->nx / patch_size) * patch_size; - int ny = ceil((float)img->ny / patch_size) * patch_size; - image_manipulation::bicubic_resize(*img, resized, nx, ny); + auto patch_size = params.patch_size * 2; + auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); + image_manipulation::bicubic_resize(*img, resized, new_size.width, new_size.height); clip_image_f32_ptr img_f32(clip_image_f32_init()); // clip_image_f32_ptr res(clip_image_f32_init()); From d2a4ef05c60506ee48e7375eb36f2257de7ab0d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 10 May 2025 22:08:07 +0200 Subject: [PATCH 15/33] vocab : add ByteDance-Seed/Seed-Coder (#13423) --- convert_hf_to_gguf.py | 3 +++ convert_hf_to_gguf_update.py | 1 + include/llama.h | 1 + src/llama-vocab.cpp | 11 +++++++++++ 4 files changed, 16 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e5c397fee..a34ba2988 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -798,6 +798,9 @@ class TextModel(ModelBase): if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3": # ref: https://huggingface.co/mistral-community/pixtral-12b res = "pixtral" + if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": + # ref: 
https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base + res = "seed-coder" if res is None: logger.warning("\n") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 03a1d8d8c..5993a4c98 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -116,6 +116,7 @@ models = [ {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", }, {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, + {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, ] diff --git a/include/llama.h b/include/llama.h index a18f365bf..7d5f9d559 100644 --- a/include/llama.h +++ b/include/llama.h @@ -112,6 +112,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, }; enum llama_rope_type { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 50ded286f..e6b96ecc8 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -415,6 +415,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1634,6 +1641,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "bailingmoe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; + } else if ( + tokenizer_pre == "seed-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } From 0208355f42bdab88a08507ead4a6302790a08323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 10 May 2025 22:22:48 +0200 Subject: [PATCH 16/33] CUDA: fix race conditions FlashAttention kernels (#13438) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 ++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 1 + 2 files changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b2f95fa3f..9873ea755 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index ef0addc1d..d96e39212 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { KQ[j*D + tid] = -HALF_MAX_HALF; } + __syncthreads(); half2 VKQ[ncols] = {{0.0f, 
0.0f}}; From 62d4250e52917dc4d634f054b1e2183ed7dee944 Mon Sep 17 00:00:00 2001 From: Thomas Germer <99991@users.noreply.github.com> Date: Sat, 10 May 2025 22:26:46 +0200 Subject: [PATCH 17/33] docs : Fix typo in InternVL3 model name (#13440) --- docs/multimodal.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/multimodal.md b/docs/multimodal.md index 82b9411e4..6a5d2b342 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -69,9 +69,9 @@ NOTE: some models may require large context window, for example: `-c 8192` # InternVL 2.5 and 3 (tool_name) -hf ggml-org/InternVL2_5-1B-GGUF -(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF +(tool_name) -hf ggml-org/InternVL2_5-4B-GGUF (tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF -(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF ``` From a634d75d1bb4034b8c945042ceea268b5b895730 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 11 May 2025 11:34:23 +0200 Subject: [PATCH 18/33] mtmd : move helpers to dedicated file (#13442) * mtmd : move helpers to dedicated file * fix windows build * rm redundant include --- tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/mtmd-helper.cpp | 310 +++++++++++++++++++++++++++++++++++ tools/mtmd/mtmd.cpp | 324 ++----------------------------------- tools/mtmd/mtmd.h | 1 + 4 files changed, 325 insertions(+), 311 deletions(-) create mode 100644 tools/mtmd/mtmd-helper.cpp diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 27b6d27e5..dfafa9cf8 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -28,6 +28,7 @@ endif() add_library(mtmd OBJECT mtmd.cpp + mtmd-helper.cpp mtmd.h clip.cpp clip.h diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp new file mode 100644 index 000000000..7a3288672 --- /dev/null +++ b/tools/mtmd/mtmd-helper.cpp @@ -0,0 +1,310 @@ +#include "mtmd.h" +#include "llama.h" + +#include +#include +#include + +#define LOG_INF(...) fprintf(stdout, __VA_ARGS__) +#define LOG_ERR(...) 
fprintf(stderr, __VA_ARGS__) + +size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) { + size_t n_tokens = 0; + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_tokens += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image); + } else { + GGML_ASSERT(false && "chunk type not supported"); + } + } + return n_tokens; +} + +llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { + llama_pos n_pos = 0; + for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { + auto chunk = mtmd_input_chunks_get(chunks, i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens_text; + mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); + n_pos += n_tokens_text; + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); + n_pos += mtmd_image_tokens_get_n_pos(tokens_image); + } else { + GGML_ASSERT(false && "chunk type not supported"); + } + } + return n_pos; +} + +// helper struct to make working with embd batch easier +// note: this will be removed after llama_batch_ext refactoring +struct decode_embd_batch { + int n_pos_per_embd; + int n_mmproj_embd; + std::vector pos; + std::vector pos_view; // used by mrope + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) { + pos .resize(n_tokens * n_pos_per_embd); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + } + + void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) { + seq_id_0[0] = seq_id; + for (int i = 0; i < batch.n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } + + void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) { + GGML_ASSERT(n_pos_per_embd == 4); + seq_id_0[0] = seq_id; + for (int y = 0; y < ny; y++) { + for (int x = 0; x < nx; x++) { + int i = y * nx + x; + pos[i ] = pos_0; + pos[i + batch.n_tokens ] = pos_0 + y; + pos[i + batch.n_tokens * 2] = pos_0 + x; + pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused + } + } + for (int i = 0; i < batch.n_tokens; i++) { + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } + + llama_batch get_view(int offset, int n_tokens) { + llama_pos * pos_ptr; + pos_view.clear(); + pos_view.reserve(n_tokens * n_pos_per_embd); + if (n_pos_per_embd > 1) { + // mrope + // for example, with layout of src: 1234...1234...1234...1234... + // offset 2 will give us dst: 34...34...34...34... 
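+            // (illustrative note, assuming n_pos_per_embd = 4: pos holds four contiguous streams of
+            //  length batch.n_tokens, namely a constant pos_0, then pos_0 + y, then pos_0 + x, then the
+            //  unused dim; slicing every stream at the same offset keeps the per-token M-RoPE position
+            //  layout intact for the viewed sub-batch)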
+ for (int i = 0; i < n_pos_per_embd; i++) { + // assume n_tokens is less than or equal to batch.n_tokens + // batch.n_tokens is number of **total** tokens + // n_tokens is number of viewed token + size_t src_idx = i * batch.n_tokens + offset; + pos_view.insert(pos_view.end(), + pos.data() + src_idx, + pos.data() + src_idx + n_tokens); + } + pos_ptr = pos_view.data(); + } else { + // normal + pos_ptr = pos.data() + offset; + } + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ batch.embd + offset * n_mmproj_embd, + /*pos =*/ pos_ptr, + /*n_seq_id =*/ batch.n_seq_id + offset, + /*seq_id =*/ batch.seq_id + offset, + /*logits =*/ batch.logits + offset, + }; + } +}; + +// Helper function for decoding an image whose embeddings have already been calculated +int32_t mtmd_helper_decode_image_chunk( + mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + float * encoded_embd, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + llama_pos * new_n_past) { + if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) { + LOG_ERR("failed to decode image chunk: input chunk not of image type\n"); + return -1; + } + const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + if (!image_tokens) { + LOG_ERR("failed to decode image chunk: image tokens are null\n"); + return -1; + } + + const llama_model * model = llama_get_model(lctx); + int n_mmproj_embd = llama_model_n_embd(model); + int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1; + + int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); + int32_t i_batch = 0; + int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd); + + const int nx = mtmd_image_tokens_get_nx(image_tokens); + const int ny = mtmd_image_tokens_get_ny(image_tokens); + + if (mtmd_decode_use_mrope(ctx)) { + batch_embd.set_position_mrope(n_past, nx, ny, seq_id); + } else { + batch_embd.set_position_normal(n_past, seq_id); + } + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, false); + // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image + } + + while (i_batch < n_img_batches) { // split into batches + int pos_offset = i_batch*n_batch; + int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); + llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); + + LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); + + int64_t t1 = ggml_time_ms(); + int32_t ret = llama_decode(lctx, batch_embd_view); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_set_causal_attn(lctx, true); // restore causal attn + return ret; + } + + LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); + + i_batch++; + } + + n_past += mtmd_image_tokens_get_n_pos(image_tokens); + *new_n_past = n_past; + + if (mtmd_decode_use_non_causal(ctx)) { + llama_set_causal_attn(lctx, true); + } + return 0; +} + +int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunk * chunk, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { + int32_t ret; + llama_batch text_batch = llama_batch_init(n_batch, 0, 1); + auto chunk_type = mtmd_input_chunk_get_type(chunk); + + if (chunk_type == 
MTMD_INPUT_CHUNK_TYPE_TEXT) { + size_t n_tokens; + const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + // LOG_INF("decoding text chunk, n_tokens = %zu\n", n_tokens); + size_t i = 0; + while (i < n_tokens) { // split into batches + text_batch.n_tokens = 0; // clear the batch + for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) { + text_batch.n_tokens++; + text_batch.token [i] = tokens[i]; + text_batch.pos [i] = n_past++; + text_batch.n_seq_id[i] = 1; + text_batch.seq_id [i][0] = seq_id; + text_batch.logits [i] = false; + } + bool is_last_token = (i == n_tokens); + if (logits_last && is_last_token) { + text_batch.logits[text_batch.n_tokens - 1] = true; + } + ret = llama_decode(lctx, text_batch); + if (ret != 0) { + LOG_ERR("failed to decode text\n"); + llama_batch_free(text_batch); + return ret; + } + *new_n_past += text_batch.n_tokens; + } + + } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); + int64_t t0 = ggml_time_ms(); + + LOG_INF("encoding image or slice...\n"); + + ret = mtmd_encode(ctx, image_tokens); + if (ret != 0) { + LOG_ERR("failed to encode image\n"); + llama_batch_free(text_batch); + return ret; + } + + LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); + + float * embd = mtmd_get_output_embd(ctx); + ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + if (ret != 0) { + LOG_ERR("failed to decode image\n"); + llama_batch_free(text_batch); + return ret; + } + } else { + GGML_ABORT("chunk type not supported"); + } + + return 0; +} + +int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, + struct llama_context * lctx, + const mtmd_input_chunks * chunks, + llama_pos n_past, + llama_seq_id seq_id, + int32_t n_batch, + bool logits_last, + llama_pos * new_n_past) { + size_t n_chunks = mtmd_input_chunks_size(chunks); + if (n_chunks == 0) { + LOG_ERR("no chunks to eval\n"); + return 0; + } + + for (size_t i = 0; i < n_chunks; i++) { + bool chunk_logits_last = (i == n_chunks - 1) && logits_last; + auto chunk = mtmd_input_chunks_get(chunks, i); + + int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past); + if (res != 0) { + LOG_ERR("failed to eval chunk %zu\n", i); + return res; + } + *new_n_past = n_past; + } + + return 0; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index f1b957394..2a852d9c1 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -461,307 +461,26 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) { - size_t n_tokens = 0; - for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { - auto chunk = mtmd_input_chunks_get(chunks, i); - auto chunk_type = mtmd_input_chunk_get_type(chunk); - if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens_text; - mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); - n_tokens += n_tokens_text; - } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); - n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image); - } else { - GGML_ASSERT(false && "chunk type not supported"); - } +bool mtmd_decode_use_non_causal(mtmd_context * ctx) { + projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); + if (proj_type == PROJECTOR_TYPE_GEMMA3) { + return true; } - return n_tokens; + return false; } -llama_pos 
mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { - llama_pos n_pos = 0; - for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { - auto chunk = mtmd_input_chunks_get(chunks, i); - auto chunk_type = mtmd_input_chunk_get_type(chunk); - if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens_text; - mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); - n_pos += n_tokens_text; - } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); - n_pos += mtmd_image_tokens_get_n_pos(tokens_image); - } else { - GGML_ASSERT(false && "chunk type not supported"); - } - } - return n_pos; +bool mtmd_decode_use_mrope(mtmd_context * ctx) { + return ctx->use_mrope; } -// helper struct to make working with embd batch easier -// note: this will be removed after llama_batch_ext refactoring -struct decode_embd_batch { - int n_pos_per_embd; - int n_mmproj_embd; - std::vector pos; - std::vector pos_view; // used by mrope - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) { - pos .resize(n_tokens * n_pos_per_embd); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - } - - void set_position_normal(llama_pos pos_0, llama_seq_id seq_id) { - seq_id_0[0] = seq_id; - for (int i = 0; i < batch.n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } - - void set_position_mrope(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) { - GGML_ASSERT(n_pos_per_embd == 4); - seq_id_0[0] = seq_id; - for (int y = 0; y < ny; y++) { - for (int x = 0; x < nx; x++) { - int i = y * nx + x; - pos[i ] = pos_0; - pos[i + batch.n_tokens ] = pos_0 + y; - pos[i + batch.n_tokens * 2] = pos_0 + x; - pos[i + batch.n_tokens * 3] = 0; // last pos dim is unused - } - } - for (int i = 0; i < batch.n_tokens; i++) { - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } - - llama_batch get_view(int offset, int n_tokens) { - llama_pos * pos_ptr; - pos_view.clear(); - pos_view.reserve(n_tokens * n_pos_per_embd); - if (n_pos_per_embd > 1) { - // mrope - // for example, with layout of src: 1234...1234...1234...1234... - // offset 2 will give us dst: 34...34...34...34... 
- for (int i = 0; i < n_pos_per_embd; i++) { - // assume n_tokens is less than or equal to batch.n_tokens - // batch.n_tokens is number of **total** tokens - // n_tokens is number of viewed token - size_t src_idx = i * batch.n_tokens + offset; - pos_view.insert(pos_view.end(), - pos.data() + src_idx, - pos.data() + src_idx + n_tokens); - } - pos_ptr = pos_view.data(); - } else { - // normal - pos_ptr = pos.data() + offset; - } - return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ batch.embd + offset * n_mmproj_embd, - /*pos =*/ pos_ptr, - /*n_seq_id =*/ batch.n_seq_id + offset, - /*seq_id =*/ batch.seq_id + offset, - /*logits =*/ batch.logits + offset, - }; - } -}; - -// Helper function for decoding an image whose embeddings have already been calculated -int32_t mtmd_helper_decode_image_chunk( - mtmd_context * ctx, - struct llama_context * lctx, - const mtmd_input_chunk * chunk, - float * encoded_embd, - llama_pos n_past, - llama_seq_id seq_id, - int32_t n_batch, - llama_pos * new_n_past) { - if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) { - LOG_ERR("failed to decode image chunk: input chunk not of image type\n"); - return -1; - } - const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); - if (!image_tokens) { - LOG_ERR("failed to decode image chunk: image tokens are null\n"); - return -1; - } - - int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); - int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1; - - int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens); - int32_t i_batch = 0; - int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; - decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd); - - const int nx = mtmd_image_tokens_get_nx(image_tokens); - const int ny = mtmd_image_tokens_get_ny(image_tokens); - - if (mtmd_decode_use_mrope(ctx)) { - batch_embd.set_position_mrope(n_past, nx, ny, seq_id); - } else { - batch_embd.set_position_normal(n_past, seq_id); - } - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, false); - // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image - } - - while (i_batch < n_img_batches) { // split into batches - int pos_offset = i_batch*n_batch; - int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset); - llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch); - - LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch); - - int64_t t1 = ggml_time_ms(); - int32_t ret = llama_decode(lctx, batch_embd_view); - if (ret != 0) { - LOG_ERR("failed to decode image\n"); - llama_set_causal_attn(lctx, true); // restore causal attn - return ret; - } - - if (ctx->print_timings) { - LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1); - } - - i_batch++; - } - - n_past += mtmd_image_tokens_get_n_pos(image_tokens); - *new_n_past = n_past; - - if (mtmd_decode_use_non_causal(ctx)) { - llama_set_causal_attn(lctx, true); - } - return 0; +void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { + mtmd_image_tokens_free(val); } -int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, - struct llama_context * lctx, - const mtmd_input_chunk * chunk, - llama_pos n_past, - llama_seq_id seq_id, - int32_t n_batch, - bool logits_last, - llama_pos * new_n_past) { - int32_t ret; - llama_batch text_batch = llama_batch_init(n_batch, 0, 1); - auto 
chunk_type = mtmd_input_chunk_get_type(chunk); - - if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - LOG_DBG("decoding text chunk, n_tokens = %zu\n", n_tokens); - size_t i = 0; - while (i < n_tokens) { // split into batches - text_batch.n_tokens = 0; // clear the batch - for (; i < n_tokens && text_batch.n_tokens < n_batch; i++) { - text_batch.n_tokens++; - text_batch.token [i] = tokens[i]; - text_batch.pos [i] = n_past++; - text_batch.n_seq_id[i] = 1; - text_batch.seq_id [i][0] = seq_id; - text_batch.logits [i] = false; - } - bool is_last_token = (i == n_tokens); - if (logits_last && is_last_token) { - text_batch.logits[text_batch.n_tokens - 1] = true; - } - ret = llama_decode(lctx, text_batch); - if (ret != 0) { - LOG_ERR("failed to decode text\n"); - llama_batch_free(text_batch); - return ret; - } - *new_n_past += text_batch.n_tokens; - } - - } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk); - int64_t t0 = ggml_time_ms(); - if (ctx->print_timings) { - LOG_INF("encoding image or slice...\n"); - } - ret = mtmd_encode(ctx, image_tokens); - if (ret != 0) { - LOG_ERR("failed to encode image\n"); - llama_batch_free(text_batch); - return ret; - } - if (ctx->print_timings) { - LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - } - float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); - if (ret != 0) { - LOG_ERR("failed to decode image\n"); - llama_batch_free(text_batch); - return ret; - } - } else { - GGML_ABORT("chunk type not supported"); - } - - return 0; -} - -int32_t mtmd_helper_eval_chunks(mtmd_context * ctx, - struct llama_context * lctx, - const mtmd_input_chunks * chunks, - llama_pos n_past, - llama_seq_id seq_id, - int32_t n_batch, - bool logits_last, - llama_pos * new_n_past) { - size_t n_chunks = mtmd_input_chunks_size(chunks); - if (n_chunks == 0) { - LOG_WRN("no chunks to eval\n"); - return 0; - } - - for (size_t i = 0; i < n_chunks; i++) { - bool chunk_logits_last = (i == n_chunks - 1) && logits_last; - auto chunk = mtmd_input_chunks_get(chunks, i); - - int32_t res = mtmd_helper_eval_chunk_single(ctx, lctx, chunk, n_past, seq_id, n_batch, chunk_logits_last, &n_past); - if (res != 0) { - LOG_ERR("failed to eval chunk %zu\n", i); - return res; - } - *new_n_past = n_past; - } - - return 0; -} +// these 2 helpers below use internal clip_image_u8_ptr, +// so unfortunately they cannot moved to mtmd-helper.h +// however, in theory, user can decode image file to bitmap using +// whichever library they want, and then use mtmd_bitmap_init() to create bitmap mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) { clip_image_u8_ptr img_u8(clip_image_u8_init()); @@ -787,23 +506,6 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) { return mtmd_bitmap_init(nx, ny, data); } -bool mtmd_decode_use_non_causal(mtmd_context * ctx) { - projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); - if (proj_type == PROJECTOR_TYPE_GEMMA3) { - return true; - } - return false; -} - -bool mtmd_decode_use_mrope(mtmd_context * ctx) { - return ctx->use_mrope; -} - -void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) { - mtmd_image_tokens_free(val); -} - - // // public API functions // diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 
54cf481b6..0ada78c90 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -10,6 +10,7 @@ #include #ifdef __cplusplus +#include #include #include #include From 3eac209319a6726fd9687c6188fc6b916b65953d Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Sun, 11 May 2025 11:35:52 +0200 Subject: [PATCH 19/33] mtmd : support InternVL 3 38B and 78B mmproj (#13443) * Support InternVL 3 38B and 78B mmproj * Swap norms in clip.cpp * Group variables together --- gguf-py/gguf/constants.py | 6 ++++++ gguf-py/gguf/tensor_mapping.py | 8 ++++++++ tools/mtmd/clip-impl.h | 2 ++ tools/mtmd/clip.cpp | 15 +++++++++++++++ 4 files changed, 31 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ae5ce71ae..0e6226b90 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -483,7 +483,9 @@ class MODEL_TENSOR(IntEnum): V_ENC_EMBD_PATCH = auto() V_ENC_EMBD_POS = auto() V_ENC_ATTN_Q = auto() + V_ENC_ATTN_Q_NORM = auto() V_ENC_ATTN_K = auto() + V_ENC_ATTN_K_NORM = auto() V_ENC_ATTN_V = auto() V_ENC_INPUT_NORM = auto() V_ENC_OUTPUT = auto() @@ -742,7 +744,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm", MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out", @@ -782,7 +786,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_EMBD_PATCH, MODEL_TENSOR.V_ENC_EMBD_POS, MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_Q_NORM, MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_K_NORM, MODEL_TENSOR.V_ENC_ATTN_V, MODEL_TENSOR.V_ENC_INPUT_NORM, MODEL_TENSOR.V_ENC_OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index bf7ec3257..ecf21b2b4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -938,6 +938,10 @@ class TensorNameMap: "visual.blocks.{bid}.attn.q", # qwen2vl, generated ), + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL + ), + MODEL_TENSOR.V_ENC_ATTN_K: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", "vpm.encoder.layers.{bid}.self_attn.k_proj", @@ -946,6 +950,10 @@ class TensorNameMap: "visual.blocks.{bid}.attn.k", # qwen2vl, generated ), + MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL + ), + MODEL_TENSOR.V_ENC_ATTN_V: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", "vpm.encoder.layers.{bid}.self_attn.v_proj", diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index d7b788bf9..23036ba72 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -53,6 +53,8 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" +#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 0ebe81b07..735dfe7f7 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ 
-205,6 +205,9 @@ struct clip_layer { ggml_tensor * o_w = nullptr; ggml_tensor * o_b = nullptr; + ggml_tensor * k_norm = nullptr; + ggml_tensor * q_norm = nullptr; + // layernorm 1 ggml_tensor * ln_1_w = nullptr; ggml_tensor * ln_1_b = nullptr; @@ -1363,6 +1366,16 @@ private: Vcur = ggml_add(ctx0, Vcur, layer.v_b); } + if (layer.q_norm) { + Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); + cb(Qcur, "Qcur_norm", il); + } + + if (layer.k_norm) { + Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); + cb(Kcur, "Kcur_norm", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); @@ -1988,6 +2001,8 @@ struct clip_model_loader { layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight")); layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); + layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false); + layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false); layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias From 7f323a589f8684c0eb722e7309074cb5eac0c8b5 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Sun, 11 May 2025 20:18:39 +0800 Subject: [PATCH 20/33] Add `--no-op-offload` to improve `-ot` pp perf in MoE models like llama4 400B (#13386) --- common/arg.cpp | 7 +++++++ common/common.cpp | 1 + common/common.h | 1 + ggml/include/ggml-backend.h | 4 ++-- ggml/src/ggml-backend.cpp | 8 +++++-- include/llama.h | 1 + src/llama-context.cpp | 4 +++- src/llama-cparams.h | 1 + tests/test-opt.cpp | 2 +- tools/llama-bench/llama-bench.cpp | 35 +++++++++++++++++++++++++++++-- tools/mtmd/clip.cpp | 2 +- 11 files changed, 57 insertions(+), 9 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index e0f1d998f..a1fd4c965 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2437,6 +2437,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } )); + add_opt(common_arg( + {"--no-op-offload"}, + string_format("disable offloading host tensor operations to device (default: %s)", params.no_op_offload ? 
"true" : "false"), + [](common_params & params) { + params.no_op_offload = true; + } + )); add_opt(common_arg( {"--lora"}, "FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)", diff --git a/common/common.cpp b/common/common.cpp index bd20af233..710bf1fe2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1113,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; + cparams.op_offload = !params.no_op_offload; if (params.reranking) { cparams.embeddings = true; diff --git a/common/common.h b/common/common.h index d051d4ec9..e15356b12 100644 --- a/common/common.h +++ b/common/common.h @@ -332,6 +332,7 @@ struct common_params { bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool no_op_offload = false; // globally disable offload host tensor operations to device bool single_turn = false; // single turn chat conversation diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index ea2c1a402..778927f68 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -248,7 +248,7 @@ extern "C" { // preferrably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false); + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); // initialize buffers from a max size graph (optional) reserve_graph = build_graph(sched, max_batch_size); @@ -289,7 +289,7 @@ extern "C" { typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); // Initialize a backend scheduler, backends with low index are given priority over backends with high index - GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index c36b5abfb..6f69d895f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -674,6 +674,8 @@ struct ggml_backend_sched { char * context_buffer; size_t context_buffer_size; + bool op_offload; + int debug; }; @@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { + if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], 
tensor)) { SET_CAUSE(tensor, "1.off"); @@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, - bool parallel) { + bool parallel, + bool op_offload) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); @@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new( } sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; ggml_backend_sched_reset(sched); diff --git a/include/llama.h b/include/llama.h index 7d5f9d559..6c6d377f8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -363,6 +363,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool no_perf; // whether to measure performance timings + bool op_offload; // whether to offload host tensor operations to device }; // model quantization parameters diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fd64622b8..a12849f0e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -93,6 +93,7 @@ llama_context::llama_context( } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.op_offload = params.op_offload; const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -243,7 +244,7 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); @@ -1871,6 +1872,7 @@ llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.no_perf =*/ true, + /*.op_offload =*/ true, }; return result; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 30e550f02..246fa5777 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -30,6 +30,7 @@ struct llama_cparams { bool flash_attn; bool no_perf; bool warmup; + bool op_offload; enum llama_pooling_type pooling_type; diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index f90c92b4b..1bc160511 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -853,7 +853,7 @@ int main(void) { backends_modded.insert(backends_modded.end(), backends.begin(), backends.end()); ggml_backend_sched_t backend_sched = ggml_backend_sched_new( - backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false); + backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true); printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i])); printf(" Device description: %s\n", ggml_backend_dev_description(devs[i])); diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 078659429..5d26b506b 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -219,6 +219,7 @@ struct cmd_params { std::vector> tensor_buft_overrides; std::vector use_mmap; std::vector embeddings; + std::vector no_op_offload; ggml_numa_strategy numa; int reps; ggml_sched_priority prio; @@ 
-253,6 +254,7 @@ static const cmd_params cmd_params_defaults = { /* tensor_buft_overrides*/ { std::vector{{nullptr,nullptr}} }, /* use_mmap */ { true }, /* embeddings */ { false }, + /* no_op_offload */ { false }, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -311,6 +313,7 @@ static void print_usage(int /* argc */, char ** argv) { join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); printf(" -ot --override-tensors =;... (default: disabled)\n"); + printf(" -nopo, --no-op-offload (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); @@ -588,6 +591,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.embeddings.insert(params.embeddings.end(), p.begin(), p.end()); + } else if (arg == "-nopo" || arg == "--no-op-offload") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end()); } else if (arg == "-ts" || arg == "--tensor-split") { if (++i >= argc) { invalid_param = true; @@ -794,6 +804,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } + if (params.no_op_offload.empty()) { + params.no_op_offload = cmd_params_defaults.no_op_offload; + } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } @@ -833,6 +846,7 @@ struct cmd_params_instance { std::vector tensor_buft_overrides; bool use_mmap; bool embeddings; + bool no_op_offload; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -902,6 +916,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; + cparams.op_offload = !no_op_offload; return cparams; } @@ -921,6 +936,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & ot : params.tensor_buft_overrides) for (const auto & mmp : params.use_mmap) for (const auto & embd : params.embeddings) + for (const auto & nopo : params.no_op_offload) for (const auto & nb : params.n_batch) for (const auto & nub : params.n_ubatch) for (const auto & tk : params.type_k) @@ -959,6 +975,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_buft_overrides = */ ot, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .no_op_offload= */ nopo, }; instances.push_back(instance); } @@ -990,6 +1007,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_buft_overrides = */ ot, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .no_op_offload= */ nopo, }; instances.push_back(instance); } @@ -1021,6 +1039,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_buft_overrides = */ ot, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .no_op_offload= */ nopo, }; instances.push_back(instance); } @@ -1056,6 +1075,7 @@ struct test { std::vector tensor_buft_overrides; bool use_mmap; bool embeddings; + bool no_op_offload; int n_prompt; int n_gen; int n_depth; @@ -1089,6 +1109,7 @@ struct test { tensor_buft_overrides = inst.tensor_buft_overrides; use_mmap = inst.use_mmap; embeddings = inst.embeddings; + 
no_op_offload = inst.no_op_offload; n_prompt = inst.n_prompt; n_gen = inst.n_gen; n_depth = inst.n_depth; @@ -1134,7 +1155,7 @@ struct test { "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", - "use_mmap", "embeddings", "n_prompt", "n_gen", "n_depth", "test_time", + "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", }; return fields; @@ -1146,7 +1167,7 @@ struct test { if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" || field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" || - field == "avg_ns" || field == "stddev_ns") { + field == "avg_ns" || field == "stddev_ns" || field == "no_op_offload") { return INT; } if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" || @@ -1222,6 +1243,7 @@ struct test { tensor_buft_overrides_str, std::to_string(use_mmap), std::to_string(embeddings), + std::to_string(no_op_offload), std::to_string(n_prompt), std::to_string(n_gen), std::to_string(n_depth), @@ -1404,6 +1426,9 @@ struct markdown_printer : public printer { if (field == "test") { return 15; } + if (field == "no_op_offload") { + return 4; + } int width = std::max((int) field.length(), 10); @@ -1435,6 +1460,9 @@ struct markdown_printer : public printer { if (field == "embeddings") { return "embd"; } + if (field == "no_op_offload") { + return "nopo"; + } if (field == "tensor_split") { return "ts"; } @@ -1503,6 +1531,9 @@ struct markdown_printer : public printer { if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) { fields.emplace_back("embeddings"); } + if (params.no_op_offload.size() > 1 || params.no_op_offload != cmd_params_defaults.no_op_offload) { + fields.emplace_back("no_op_offload"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 735dfe7f7..3f11c301a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -383,7 +383,7 @@ struct clip_ctx { backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); sched.reset( - ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false) + ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true) ); } From 7474e00b34629e9cd8b06bc87ad935584ea30f8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 11 May 2025 16:09:33 +0200 Subject: [PATCH 21/33] CUDA: fix crash with partial offloading of MoE (#13439) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++-- ggml/src/ggml-cuda/mmq.cu | 4 ++-- ggml/src/ggml-cuda/mmvq.cu | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7643c4b8b..b4b85abcd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1909,13 +1909,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = 
ggml_backend_buft_is_cuda_split(src0->buffer->buft); + // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q. + // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data. + // Therefore, in such cases use cuBLAS. + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE + && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool any_gpus_with_slow_fp16 = false; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index b4962e6a5..e1cf843de 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -91,11 +91,11 @@ void ggml_cuda_mul_mat_q( // If src0 is a temporary compute buffer, clear any potential padding. if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { - GGML_ASSERT(ggml_is_contiguously_allocated(src0)); - GGML_ASSERT(!src0->view_src); const size_t size_data = ggml_nbytes(src0); const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); } } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3b313ea29..dc7adf509 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -515,11 +515,11 @@ void ggml_cuda_mul_mat_vec_q( // If src0 is a temporary compute buffer, clear any potential padding. 
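
// Editor's sketch, not part of the patch: the bad_padding_clear condition from
// ggml_cuda_mul_mat written as a standalone helper (the function name is
// illustrative). A quantized src0 in a temporary compute buffer may carry
// padding that mul_mat_q / mul_mat_vec_q would memset to zero; if that tensor
// is also a view, the memset could overwrite valid data of another tensor, so
// those kernels are skipped and the caller falls back to cuBLAS.
#include "ggml.h"
#include "ggml-backend.h"

static bool cuda_mmq_padding_clear_is_unsafe(const ggml_tensor * src0) {
    const bool is_compute_buffer =
        ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE;
    const bool has_padding =
        ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
    return is_compute_buffer && has_padding && src0->view_src != nullptr;
}
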
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { - GGML_ASSERT(ggml_is_contiguously_allocated(src0)); - GGML_ASSERT(!src0->view_src); const size_t size_data = ggml_nbytes(src0); const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); } } From 09232370fc6426aa5dd9be01a8271b9c28f5af3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 11 May 2025 16:20:39 +0200 Subject: [PATCH 22/33] scripts : exit compare-llama-bench.py gracefully when there's nothing to compare (#13451) --- scripts/compare-llama-bench.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 8c599cf9e..c32b449f7 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -318,7 +318,7 @@ else: show = [] # Show CPU and/or GPU by default even if the hardware for all results is the same: - if "n_gpu_layers" not in properties_different: + if rows_full and "n_gpu_layers" not in properties_different: ngl = int(rows_full[0][KEY_PROPERTIES.index("n_gpu_layers")]) if ngl != 99 and "cpu_info" not in properties_different: @@ -338,6 +338,10 @@ else: pass rows_show = get_rows(show) +if not rows_show: + logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n") + sys.exit(1) + table = [] for row in rows_show: n_prompt = int(row[-5]) From 9a390c4829cd3058d26a2e2c09d16e3fd12bf1b1 Mon Sep 17 00:00:00 2001 From: Anthony Umfer Date: Sun, 11 May 2025 11:08:26 -0400 Subject: [PATCH 23/33] tools : fix uninitialized llama_batch in server (#13436) * add constructor to initialize server_context::batch, preventing destructor's call to llama_batch_free from causing an invalid free() * Update tools/server/server.cpp Co-authored-by: Xuan-Son Nguyen * use C++11 initializer syntax * switch from Copy-list-initialization to Direct-list-initialization --------- Co-authored-by: Xuan-Son Nguyen --- tools/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index de8ded71f..7169ffdce 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1862,7 +1862,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch batch; + llama_batch batch {}; bool clean_kv_cache = true; bool add_bos_token = true; From c104023994d36a8e791fc6a43789b84fd552cefc Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Mon, 12 May 2025 00:39:06 +0200 Subject: [PATCH 24/33] mtmd : Use RMS norm for InternVL 3 38B and 78B mmproj (#13459) --- tools/mtmd/clip.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3f11c301a..0adf03163 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -879,9 +879,15 @@ struct clip_graph { // add CLS token inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + // The larger models use a different ViT, which uses RMS norm instead of layer norm + // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 + norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45) + ? 
NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B) + : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models) + ggml_tensor * cur = build_vit( inp, n_pos, - NORM_TYPE_NORMAL, + norm_t, hparams.ffn_op, model.position_embeddings, nullptr); From 14492144c286bbf38bb1903128403d9e2ebad54c Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Mon, 12 May 2025 06:15:32 +0100 Subject: [PATCH 25/33] enable dpcpp nightly builds with libraries (#13406) --- ggml/src/ggml-sycl/CMakeLists.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 6699b70ba..231fb71da 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -52,9 +52,8 @@ target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") find_package(DNNL) set(GGML_SYCL_DNNL 0) if(DNNL_FOUND) - if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR) - # Assuming oneDNN packaged with oneapi release is used which - # supports only intel target + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target set(DNNL_GPU_VENDOR "INTEL") if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") @@ -108,6 +107,9 @@ endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically # See https://github.com/uxlfoundation/oneMath/issues/654 + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(SYCL_COMPILER ON) + endif() find_package(MKL REQUIRED) target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) From df8491922f0ea4e032338186350dff2fb6b2b4e3 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 12 May 2025 10:29:13 +0200 Subject: [PATCH 26/33] ggml : add mrope kernel for metal (#13457) --- ggml/src/ggml-metal/ggml-metal-impl.h | 4 + ggml/src/ggml-metal/ggml-metal.m | 58 +++++++--- ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++ 3 files changed, 192 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ea8cf4586..17eab976f 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -207,6 +207,10 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sect_0; + int32_t sect_1; + int32_t sect_2; + int32_t sect_3; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 943f0fe72..576f9581b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -332,6 +332,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, @@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); 
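
// Editor's sketch, not part of the patch: the section lookup performed by
// kernel_rope_multi, written as plain C++ (the function name is illustrative).
// Given the per-dimension section sizes sect_0..sect_3 and the rotation pair
// index ic, it returns which of the four interleaved position streams in src1
// supplies theta_base.
static int mrope_section_for_channel(int ic, int sect_0, int sect_1, int sect_2, int sect_3) {
    const int sect_dims = sect_0 + sect_1 + sect_2 + sect_3;
    const int sector    = ic % sect_dims;
    if (sector < sect_0)                   return 0; // theta_base = pos[i2]
    if (sector < sect_0 + sect_1)          return 1; // theta_base = pos[i2 + ne02]
    if (sector < sect_0 + sect_1 + sect_2) return 2; // theta_base = pos[i2 + 2*ne02]
    return 3;                                        // theta_base = pos[i2 + 3*ne02]
}
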
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); @@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } + return true; case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: @@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node( } break; case GGML_OP_ROPE: { + // make sure we have one or more position id(ne10) per token(ne02) GGML_ASSERT(ne10 % ne02 == 0); GGML_ASSERT(ne10 >= ne02); @@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node( memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + // mrope + const int sect_0 = ((const int32_t *) dst->op_params)[11]; + const int sect_1 = ((const int32_t *) dst->op_params)[12]; + const int sect_2 = ((const int32_t *) dst->op_params)[13]; + const int sect_3 = ((const int32_t *) dst->op_params)[14]; id pipeline = nil; - if (!is_neox) { + if (is_neox) { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } @@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node( /*.attn_factor =*/ attn_factor, /*.beta_fast =*/ beta_fast, /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a808e79c2..9cfddf450 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox( } } +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + // mrope theta calculations + // note: the rest is the same as kernel_rope_neox + const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; + const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + + float theta_base; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < 2*args.n_dims) { // different from kernel_rope_multi + const int ic = i0/2; + + // mrope theta calculations (only support 2 dimensions) + const int sect_dims = args.sect_0 + args.sect_1; + const int sector = ic % sect_dims; + + float p; + float theta_base; + if (sector < args.sect_1) { + p = (float) sector; + theta_base = (float) pos[i2]; + } else { + p = (float) sector - args.sect_0; + theta_base = (float) pos[i2 + args.ne02]; + } + + const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); + // end of mrope + + const float freq_factor = src2 != src0 ? 
((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; // different from kernel_rope_multi + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst, From 95e18884fc7ea4031f70f1a518d5d1df616e5717 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 12 May 2025 10:51:21 +0200 Subject: [PATCH 27/33] CUDA: fix misaligned synchronization in FA (#13469) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 9873ea755..491780abd 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -895,6 +895,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } + } else if (np > 1) { + // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. + // Therefore, all other warps also need to execute a __syncthreads(). + // Otherwise the points at which warps synchronize with each other would become misaligned. 
+ __syncthreads(); } #pragma unroll From a71a4075cdb8e81eaaa834e48ffda88b038286bc Mon Sep 17 00:00:00 2001 From: Dan Johansson Date: Mon, 12 May 2025 13:06:19 +0200 Subject: [PATCH 28/33] ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (#13053) * ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel Signed-off-by: Dan Johansson * * code review fixes Signed-off-by: Dan Johansson * * adds a comment that clarifies barrier usage Signed-off-by: Dan Johansson --------- Signed-off-by: Dan Johansson Co-authored-by: Charles Xu --- ggml/src/ggml-cpu/CMakeLists.txt | 29 +- ggml/src/ggml-cpu/kleidiai/kernels.cpp | 93 ++++++- ggml/src/ggml-cpu/kleidiai/kernels.h | 58 +++- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 335 +++++++++++++++++++----- 4 files changed, 416 insertions(+), 99 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 9a3085bef..bdaec2881 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) - set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) + set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP}) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) if (NOT DOTPROD_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) endif() if (NOT I8MM_ENABLED MATCHES -1) @@ 
-456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if (NOT SME_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) - set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index aacc2bb5e..910fd0ee4 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -4,16 +4,22 @@ // KleidiAI micro-kernels #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" -#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" -#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + +#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" + +#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" + #include "kai_common.h" #include "kernels.h" @@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, }, /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, + }, + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ 
kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, + /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_F16, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__APPLE__) @@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_MATMUL_INT8) @@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #else @@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_DOTPROD) @@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] 
= { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #endif }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { + ggml_kleidiai_kernels * kernel = nullptr; + + if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && + gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && + gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && + gemm_gemv_kernels[i].op_type == tensor->type) { + kernel = &gemm_gemv_kernels[i]; + break; + } + } + } + + return kernel; +} + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ffe97eb4..5ac02bda7 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,6 +4,9 @@ #pragma once +#include +#include "ggml.h" + enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, @@ -26,26 +29,53 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); - size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + std::variant< + std::function, + std::function + > get_lhs_offset; + std::variant< + std::function, + std::function + > get_rhs_packed_offset; size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); + std::variant< + std::function, + std::function + > run_kernel; }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, - size_t lhs_stride, void* lhs_packed); + std::variant< + std::function, + std::function + > get_packed_offset; + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct rhs_packing_info { - size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); - void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, - const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct ggml_kleidiai_kernels { @@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels { rhs_packing_info rhs_info; cpu_feature required_cpu; + ggml_type lhs_type; + ggml_type 
rhs_type; + ggml_type op_type; }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 4e89ca0fa..f3dffdd6b 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -34,8 +34,9 @@ #include "ggml-common.h" struct ggml_kleidiai_context { + cpu_feature features; ggml_kleidiai_kernels * kernels; -} static ctx = { NULL }; +} static ctx = { CPU_FEATURE_NONE, NULL }; static void init_kleidiai_context(void) { @@ -47,18 +48,18 @@ static void init_kleidiai_context(void) { const char *env_var = getenv("GGML_KLEIDIAI_SME"); int sme_enabled = 0; - cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | - (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | - (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); if (env_var) { sme_enabled = atoi(env_var); } if (sme_enabled != 0) { - features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; } - ctx.kernels = ggml_kleidiai_select_kernels(features); + ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features); } ggml_critical_section_end(); } @@ -68,95 +69,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } +template +static Ret variant_call(const Variant & var, Args&&... args) { + return std::visit([&](auto&& func) -> Ret { + if constexpr (std::is_invocable_r_v) { + return func(std::forward(args)...); + } else { + throw std::runtime_error("Invalid function type in variant_call"); + } + }, var); +} + namespace ggml::cpu::kleidiai { + +static size_t round_down(size_t x, size_t y) { + return y == 0 ? x : x - (x % y); +} + +static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) { + size_t src_stride = rhs_stride / sizeof(uint16_t); + size_t dst_stride = n; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + uint16_t v = *(src + k_idx + n_idx * src_stride); + *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v); + } + } +} + class tensor_traits : public ggml::cpu::tensor_traits { bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); + GGML_ASSERT(kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? 
&kernels->gemv : &kernels->gemm; size_t k = op->src[0]->ne[0]; + size_t n = op->src[0]->ne[1]; size_t m = op->src[1]->ne[1]; size_t mr = kernel->get_mr(); size_t kr = kernel->get_kr(); size_t sr = kernel->get_sr(); - size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); + if (kernels->rhs_type == GGML_TYPE_Q4_0) { + size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); + } else if (kernels->rhs_type == GGML_TYPE_F16) { + size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + + variant_call(kernels->rhs_info.packed_size, n, k) + + k * n * sizeof(float) + n * sizeof(float); + } else { + GGML_ASSERT(false); + } return true; } + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; + if (dst->src[0]->type == GGML_TYPE_Q4_0) { + return compute_forward_q4_0(params, dst); + } else if (dst->src[0]->type == GGML_TYPE_F16) { + return compute_forward_kv_cache(params, dst); + } + } + return false; + } - GGML_TENSOR_BINARY_OP_LOCALS + bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; - lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(kernel); + GGML_TENSOR_BINARY_OP_LOCALS - const int ith = params->ith; - const int nth = params->nth; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + kernel_info * kernel = src1->ne[1] == 1 ? 
&kernels->gemv : &kernels->gemm; + GGML_ASSERT(kernel); - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + const int nth = params->nth; + const int ith = params->ith; - size_t n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + const int64_t lhs_batch_size0 = ne12; + const int64_t rhs_batch_size0 = ne02; + const int64_t batch_size = rhs_batch_size0; + + const int64_t r = lhs_batch_size0 / rhs_batch_size0; + + const int64_t m = ne11 * r; + const int64_t n = ne01; + const int64_t k = ne00; + + const size_t lhs_stride = src1->nb[1]; + const size_t rhs_stride = src0->nb[1]; + const size_t dst_stride = dst->nb[1]; + + const int64_t mr = static_cast(kernel->get_mr()); + const int64_t nr = static_cast(kernel->get_nr()); + const int64_t kr = static_cast(kernel->get_kr()); + const int64_t sr = static_cast(kernel->get_sr()); + + const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr); + const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); + + const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; + GGML_ASSERT(wsize_required <= params->wsize); + + uint8_t * lhs_packed = static_cast(params->wdata); + uint8_t * rhs_packed = lhs_packed + lhs_packed_size; + uint8_t * rhs_kxn = rhs_packed + rhs_packed_size; + uint8_t * bias = rhs_kxn + kxn_size; + + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; + const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; + uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; + + // LHS packing + { + const int64_t m_roundup_mr = kai_roundup(m, mr); + const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); + + if (ith < num_threads) { + const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); + const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; + + const int64_t m_start = ith * num_m_per_thread0; + const int64_t num_m_per_thread = (ith == num_threads - 1) ? 
num_m_per_threadN_1 : num_m_per_thread0; + + const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); + const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); + + const void * src_ptr = static_cast(lhs_batch) + lhs_offset; + void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; + + variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + } } - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast(src0->data); + // RHS packing + if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { + // First thread to reach this point handles RHS packing + memset(bias, 0, n * sizeof(float)); + transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), + reinterpret_cast(rhs_batch), rhs_stride); - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; - } - - if(m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - - lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + first_to_arrive.clear(std::memory_order_release); - kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - return true; + // Perform the matmul + { + const int64_t m_to_process = m; + const int64_t m_start = 0; + + const int64_t n_step = static_cast(kernel->get_n_step()); + const int64_t num_threads = KAI_MIN(n / n_step, nth); + + if (ith < num_threads) { + const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); + const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; + + const int64_t n_start = ith * num_n_per_thread0; + const int64_t n_to_process = (ith == num_threads - 1) ? 
num_n_per_threadN_1 : num_n_per_thread0; + + const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); + const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); + + const void * lhs_ptr = lhs_packed + lhs_packed_offset; + const void * rhs_ptr = rhs_packed + rhs_packed_offset; + float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); + + variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } + } + + if (batch_idx != batch_size - 1) { + // This barrier is necessary when the batch size is larger than 1. While processing a batch, + // the work data buffer (params->wdata) is used as temporary storage which means that only + // a single batch can be processed at any given time. No barrier is needed for the last + // batch since GGML inserts a barrier between the execution of every operator. + ggml_barrier(params->threadpool); + } } - return false; + + return true; + } + + bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); + + kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = &kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + // Calculate number of columns to be processed per thread + const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } + + if (m_start < m) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + + variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + } + + ggml_barrier(params->threadpool); + + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char 
*)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + sizeof(float), -FLT_MAX, FLT_MAX); + + return true; } public: @@ -169,13 +350,13 @@ public: size_t sr = ctx.kernels->gemm.get_sr(); #ifndef NDEBUG - const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); + const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); #endif struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); return 0; @@ -189,7 +370,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc } } // namespace ggml::cpu::kleidiai -GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); GGML_UNUSED(buffer); @@ -238,12 +419,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if ( op->op == GGML_OP_MUL_MAT && - op->src[0]->type == GGML_TYPE_Q4_0 && - op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels - ) { + if (op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } @@ -260,6 +440,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } + else if (ggml_kleidiai_select_kernels(ctx.features, op) && + op->src[0]->op == GGML_OP_VIEW && + (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && + op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[0] != 2) || + (op->src[1]->nb[0] != 4) || + (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + return nullptr; + } + + return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); + } } return nullptr; } From 22cdab343b63edd0906f1c132616746085f81983 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Mon, 12 May 2025 13:08:22 +0200 Subject: [PATCH 29/33] llama-bench : accept ranges for integer parameters (#13410) --- tools/llama-bench/README.md | 28 +- tools/llama-bench/llama-bench.cpp | 721 ++++++++++++++++-------------- 2 files changed, 404 insertions(+), 345 deletions(-) diff --git a/tools/llama-bench/README.md b/tools/llama-bench/README.md index d6fc77df8..4fb2a24e1 100644 --- 
a/tools/llama-bench/README.md +++ b/tools/llama-bench/README.md @@ -20,10 +20,20 @@ Performance testing tool for llama.cpp. ## Syntax ``` -usage: ./llama-bench [options] +usage: llama-bench [options] options: -h, --help + --numa numa mode (default: disabled) + -r, --repetitions number of times to repeat each test (default: 5) + --prio <0|1|2|3> process/thread priority (default: 0) + --delay <0...N> (seconds) delay between each test (default: 0) + -o, --output output format printed to stdout (default: md) + -oe, --output-err output format printed to stderr (default: none) + -v, --verbose verbose output + --progress print test progress indicators + +test parameters: -m, --model (default: models/7B/ggml-model-q4_0.gguf) -p, --n-prompt (default: 512) -n, --n-gen (default: 128) @@ -33,7 +43,7 @@ options: -ub, --ubatch-size (default: 512) -ctk, --cache-type-k (default: f16) -ctv, --cache-type-v (default: f16) - -t, --threads (default: 8) + -t, --threads (default: 16) -C, --cpu-mask (default: 0x0) --cpu-strict <0|1> (default: 0) --poll <0...100> (default: 50) @@ -44,17 +54,15 @@ options: -nkvo, --no-kv-offload <0|1> (default: 0) -fa, --flash-attn <0|1> (default: 0) -mmp, --mmap <0|1> (default: 1) - --numa (default: disabled) -embd, --embeddings <0|1> (default: 0) -ts, --tensor-split (default: 0) - -r, --repetitions (default: 5) - --prio <0|1|2|3> (default: 0) - --delay <0...N> (seconds) (default: 0) - -o, --output (default: md) - -oe, --output-err (default: none) - -v, --verbose (default: 0) + -ot --override-tensors =;... + (default: disabled) + -nopo, --no-op-offload <0|1> (default: 0) -Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times. +Multiple values can be given for each parameter by separating them with ',' +or by specifying the parameter multiple times. Ranges can be given as +'start-end' or 'start-end+step' or 'start-end*mult'. ``` llama-bench can perform three types of tests: diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 5d26b506b..ca0d0aed5 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -195,6 +195,40 @@ static std::string pair_str(const std::pair & p) { return buf; } +static std::vector parse_int_range(const std::string & s) { + // first[-last[(+|*)step]] + std::regex range_regex(R"(^(\d+)(?:-(\d+)(?:([\+|\*])(\d+))?)?(?:,|$))"); + + std::smatch match; + std::string::const_iterator search_start(s.cbegin()); + std::vector result; + while (std::regex_search(search_start, s.cend(), match, range_regex)) { + int first = std::stoi(match[1]); + int last = match[2].matched ? std::stoi(match[2]) : first; + char op = match[3].matched ? match[3].str()[0] : '+'; + int step = match[4].matched ? 
std::stoi(match[4]) : 1; + + for (int i = first; i <= last;) { + result.push_back(i); + + if (op == '+') { + i += step; + } else if (op == '*') { + i *= step; + } else { + throw std::invalid_argument("invalid range format"); + } + } + search_start = match.suffix().first; + } + + if (search_start != s.cend()) { + throw std::invalid_argument("invalid range format"); + } + + return result; +} + struct cmd_params { std::vector model; std::vector n_prompt; @@ -251,7 +285,7 @@ static const cmd_params cmd_params_defaults = { /* no_kv_offload */ { false }, /* flash_attn */ { false }, /* tensor_split */ { std::vector(llama_max_devices(), 0.0f) }, - /* tensor_buft_overrides*/ { std::vector{{nullptr,nullptr}} }, + /* tensor_buft_overrides*/ { std::vector{ { nullptr, nullptr } } }, /* use_mmap */ { true }, /* embeddings */ { false }, /* no_op_offload */ { false }, @@ -270,13 +304,29 @@ static void print_usage(int /* argc */, char ** argv) { printf("\n"); printf("options:\n"); printf(" -h, --help\n"); + printf(" --numa numa mode (default: disabled)\n"); + printf(" -r, --repetitions number of times to repeat each test (default: %d)\n", + cmd_params_defaults.reps); + printf(" --prio <0|1|2|3> process/thread priority (default: %d)\n", + cmd_params_defaults.prio); + printf(" --delay <0...N> (seconds) delay between each test (default: %d)\n", + cmd_params_defaults.delay); + printf(" -o, --output output format printed to stdout (default: %s)\n", + output_format_str(cmd_params_defaults.output_format)); + printf(" -oe, --output-err output format printed to stderr (default: %s)\n", + output_format_str(cmd_params_defaults.output_format_stderr)); + printf(" -v, --verbose verbose output\n"); + printf(" --progress print test progress indicators\n"); + printf("\n"); + printf("test parameters:\n"); printf(" -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); - printf(" -d, --n-depth (default: %s)\n", join(cmd_params_defaults.n_depth, ",").c_str()); + printf(" -d, --n-depth (default: %s)\n", + join(cmd_params_defaults.n_depth, ",").c_str()); printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); printf(" -ub, --ubatch-size (default: %s)\n", @@ -308,25 +358,17 @@ static void print_usage(int /* argc */, char ** argv) { join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); - printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); - printf(" -ot --override-tensors =;... (default: disabled)\n"); - printf(" -nopo, --no-op-offload (default: 0)\n"); - printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); - printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); - printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); - printf(" -o, --output (default: %s)\n", - output_format_str(cmd_params_defaults.output_format)); - printf(" -oe, --output-err (default: %s)\n", - output_format_str(cmd_params_defaults.output_format_stderr)); - printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? 
"1" : "0"); - printf(" --progress (default: %s)\n", cmd_params_defaults.progress ? "1" : "0"); + printf(" -ot --override-tensors =;...\n"); + printf(" (default: disabled)\n"); + printf(" -nopo, --no-op-offload <0|1> (default: 0)\n"); printf("\n"); printf( - "Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter " - "multiple times.\n"); + "Multiple values can be given for each parameter by separating them with ','\n" + "or by specifying the parameter multiple times. Ranges can be given as\n" + "'start-end' or 'start-end+step' or 'start-end*mult'.\n"); } static ggml_type ggml_type_from_name(const std::string & s) { @@ -380,186 +422,190 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { std::replace(arg.begin(), arg.end(), '_', '-'); } - if (arg == "-h" || arg == "--help") { - print_usage(argc, argv); - exit(0); - } else if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.model.insert(params.model.end(), p.begin(), p.end()); - } else if (arg == "-p" || arg == "--n-prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end()); - } else if (arg == "-n" || arg == "--n-gen") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_gen.insert(params.n_gen.end(), p.begin(), p.end()); - } else if (arg == "-pg") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], ','); - if (p.size() != 2) { - invalid_param = true; - break; - } - params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) }); - } else if (arg == "-d" || arg == "--n-depth") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_depth.insert(params.n_depth.end(), p.begin(), p.end()); - } else if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_batch.insert(params.n_batch.end(), p.begin(), p.end()); - } else if (arg == "-ub" || arg == "--ubatch-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end()); - } else if (arg == "-ctk" || arg == "--cache-type-k") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - std::vector types; - for (const auto & t : p) { - ggml_type gt = ggml_type_from_name(t); - if (gt == GGML_TYPE_COUNT) { + try { + if (arg == "-h" || arg == "--help") { + print_usage(argc, argv); + exit(0); + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { invalid_param = true; break; } - types.push_back(gt); - } - if (invalid_param) { - break; - } - params.type_k.insert(params.type_k.end(), types.begin(), types.end()); - } else if (arg == "-ctv" || arg == "--cache-type-v") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - std::vector types; - for (const auto & t : p) { - ggml_type gt = ggml_type_from_name(t); - if (gt == GGML_TYPE_COUNT) { + auto p = string_split(argv[i], split_delim); + params.model.insert(params.model.end(), p.begin(), p.end()); + } else if (arg == "-p" || arg == "--n-prompt") { + if (++i >= argc) { 
invalid_param = true; break; } - types.push_back(gt); - } - if (invalid_param) { - break; - } - params.type_v.insert(params.type_v.end(), types.begin(), types.end()); - } else if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); - } else if (arg == "-C" || arg == "--cpu-mask") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end()); - } else if (arg == "--cpu-strict") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end()); - } else if (arg == "--poll") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.poll.insert(params.poll.end(), p.begin(), p.end()); - } else if (arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end()); - } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rpc_servers.push_back(argv[i]); - } else if (arg == "-sm" || arg == "--split-mode") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - std::vector modes; - for (const auto & m : p) { - llama_split_mode mode; - if (m == "none") { - mode = LLAMA_SPLIT_MODE_NONE; - } else if (m == "layer") { - mode = LLAMA_SPLIT_MODE_LAYER; - } else if (m == "row") { - mode = LLAMA_SPLIT_MODE_ROW; - } else { + auto p = parse_int_range(argv[i]); + params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end()); + } else if (arg == "-n" || arg == "--n-gen") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_gen.insert(params.n_gen.end(), p.begin(), p.end()); + } else if (arg == "-pg") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], ','); + if (p.size() != 2) { + invalid_param = true; + break; + } + params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) }); + } else if (arg == "-d" || arg == "--n-depth") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_depth.insert(params.n_depth.end(), p.begin(), p.end()); + } else if (arg == "-b" || arg == "--batch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_batch.insert(params.n_batch.end(), p.begin(), p.end()); + } else if (arg == "-ub" || arg == "--ubatch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end()); + } else if (arg == "-ctk" || arg == "--cache-type-k") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + + std::vector types; + for (const auto & t : p) { + ggml_type gt = ggml_type_from_name(t); + if (gt == GGML_TYPE_COUNT) { + invalid_param = true; + break; + } + types.push_back(gt); + } + if (invalid_param) { + break; + } + params.type_k.insert(params.type_k.end(), types.begin(), types.end()); + } else if (arg == "-ctv" 
|| arg == "--cache-type-v") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + + std::vector types; + for (const auto & t : p) { + ggml_type gt = ggml_type_from_name(t); + if (gt == GGML_TYPE_COUNT) { + invalid_param = true; + break; + } + types.push_back(gt); + } + if (invalid_param) { + break; + } + params.type_v.insert(params.type_v.end(), types.begin(), types.end()); + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + } else if (arg == "-C" || arg == "--cpu-mask") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end()); + } else if (arg == "--cpu-strict") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end()); + } else if (arg == "--poll") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.poll.insert(params.poll.end(), p.begin(), p.end()); + } else if (arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = parse_int_range(argv[i]); + params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end()); + } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rpc_servers.push_back(argv[i]); + } else if (arg == "-sm" || arg == "--split-mode") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + + std::vector modes; + for (const auto & m : p) { + llama_split_mode mode; + if (m == "none") { + mode = LLAMA_SPLIT_MODE_NONE; + } else if (m == "layer") { + mode = LLAMA_SPLIT_MODE_LAYER; + } else if (m == "row") { + mode = LLAMA_SPLIT_MODE_ROW; + } else { + invalid_param = true; + break; + } + modes.push_back(mode); + } + if (invalid_param) { + break; + } + params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end()); + } else if (arg == "-mg" || arg == "--main-gpu") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.main_gpu = parse_int_range(argv[i]); + } else if (arg == "-nkvo" || arg == "--no-kv-offload") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "--numa") { + if (++i >= argc) { invalid_param = true; break; } - modes.push_back(mode); - } - if (invalid_param) { - break; - } - params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end()); - } else if (arg == "-mg" || arg == "--main-gpu") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.main_gpu = string_split(argv[i], split_delim); - } else if (arg == "-nkvo" || arg == "--no-kv-offload") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); - } else if (arg == "--numa") { - if (++i >= argc) { - invalid_param = true; - break; - } else { std::string value(argv[i]); - /**/ if (value == "distribute" || value == "") { + if (value == "distribute" || value == "") { params.numa 
= GGML_NUMA_STRATEGY_DISTRIBUTE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; @@ -569,177 +615,182 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { invalid_param = true; break; } - } - } else if (arg == "-fa" || arg == "--flash-attn") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); - } else if (arg == "-mmp" || arg == "--mmap") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end()); - } else if (arg == "-embd" || arg == "--embeddings") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.embeddings.insert(params.embeddings.end(), p.begin(), p.end()); - } else if (arg == "-nopo" || arg == "--no-op-offload") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = string_split(argv[i], split_delim); - params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end()); - } else if (arg == "-ts" || arg == "--tensor-split") { - if (++i >= argc) { - invalid_param = true; - break; - } - for (auto ts : string_split(argv[i], split_delim)) { - // split string by ; and / - const std::regex regex{ R"([;/]+)" }; - std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 }; - std::vector split_arg{ it, {} }; - GGML_ASSERT(split_arg.size() <= llama_max_devices()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-mmp" || arg == "--mmap") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end()); + } else if (arg == "-embd" || arg == "--embeddings") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.embeddings.insert(params.embeddings.end(), p.begin(), p.end()); + } else if (arg == "-nopo" || arg == "--no-op-offload") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end()); + } else if (arg == "-ts" || arg == "--tensor-split") { + if (++i >= argc) { + invalid_param = true; + break; + } + for (auto ts : string_split(argv[i], split_delim)) { + // split string by ; and / + const std::regex regex{ R"([;/]+)" }; + std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + GGML_ASSERT(split_arg.size() <= llama_max_devices()); - std::vector tensor_split(llama_max_devices()); - for (size_t i = 0; i < llama_max_devices(); ++i) { - if (i < split_arg.size()) { - tensor_split[i] = std::stof(split_arg[i]); - } else { - tensor_split[i] = 0.0f; + std::vector tensor_split(llama_max_devices()); + for (size_t i = 0; i < llama_max_devices(); ++i) { + if (i < split_arg.size()) { + tensor_split[i] = std::stof(split_arg[i]); + } else { + tensor_split[i] = 0.0f; + } + } + params.tensor_split.push_back(tensor_split); + } + } else if (arg == "-ot" || arg == "--override-tensor") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto value = argv[i]; + /* static */ std::map buft_list; + if 
(buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } } } - params.tensor_split.push_back(tensor_split); - } - } else if (arg == "-ot" || arg == "--override-tensor") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto value = argv[i]; - /* static */ std::map buft_list; - if (buft_list.empty()) { - // enumerate all the devices and add their buffer types to the list - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - auto * dev = ggml_backend_dev_get(i); - auto * buft = ggml_backend_dev_buffer_type(dev); - if (buft) { - buft_list[ggml_backend_buft_name(buft)] = buft; + auto override_group_span_len = std::strcspn(value, ","); + bool last_group = false; + do { + if (override_group_span_len == 0) { + // Adds an empty override-tensors for an empty span + params.tensor_buft_overrides.push_back({{}}); + if (value[override_group_span_len] == '\0') { + value = &value[override_group_span_len]; + last_group = true; + } else { + value = &value[override_group_span_len + 1]; + override_group_span_len = std::strcspn(value, ","); + } + continue; } - } - } - auto override_group_span_len = std::strcspn(value, ","); - bool last_group = false; - do { - if (override_group_span_len == 0) { - // Adds an empty override-tensors for an empty span - params.tensor_buft_overrides.push_back({{}}); + // Stamps null terminators into the argv + // value for this option to avoid the + // memory leak present in the implementation + // over in arg.cpp. Acceptable because we + // only parse these args once in this program. + auto override_group = value; if (value[override_group_span_len] == '\0') { value = &value[override_group_span_len]; last_group = true; } else { + value[override_group_span_len] = '\0'; value = &value[override_group_span_len + 1]; - override_group_span_len = std::strcspn(value, ","); } - continue; - } - // Stamps null terminators into the argv - // value for this option to avoid the - // memory leak present in the implementation - // over in arg.cpp. Acceptable because we - // only parse these args once in this program. 
- auto override_group = value; - if (value[override_group_span_len] == '\0') { - value = &value[override_group_span_len]; - last_group = true; - } else { - value[override_group_span_len] = '\0'; - value = &value[override_group_span_len + 1]; - } - std::vector group_tensor_buft_overrides{}; - auto override_span_len = std::strcspn(override_group, ";"); - while (override_span_len > 0) { - auto override = override_group; - if (override_group[override_span_len] != '\0') { - override_group[override_span_len] = '\0'; - override_group = &override_group[override_span_len + 1]; - } else { - override_group = &override_group[override_span_len]; - } - auto tensor_name_span_len = std::strcspn(override, "="); - if (tensor_name_span_len >= override_span_len) { - invalid_param = true; - break; - } - override[tensor_name_span_len] = '\0'; - auto tensor_name = override; - auto buffer_type = &override[tensor_name_span_len + 1]; - if (buft_list.find(buffer_type) == buft_list.end()) { - printf("Available buffer types:\n"); - for (const auto & it : buft_list) { - printf(" %s\n", ggml_backend_buft_name(it.second)); + std::vector group_tensor_buft_overrides{}; + auto override_span_len = std::strcspn(override_group, ";"); + while (override_span_len > 0) { + auto override = override_group; + if (override_group[override_span_len] != '\0') { + override_group[override_span_len] = '\0'; + override_group = &override_group[override_span_len + 1]; + } else { + override_group = &override_group[override_span_len]; } - invalid_param = true; + auto tensor_name_span_len = std::strcspn(override, "="); + if (tensor_name_span_len >= override_span_len) { + invalid_param = true; + break; + } + override[tensor_name_span_len] = '\0'; + auto tensor_name = override; + auto buffer_type = &override[tensor_name_span_len + 1]; + if (buft_list.find(buffer_type) == buft_list.end()) { + printf("Available buffer types:\n"); + for (const auto & it : buft_list) { + printf(" %s\n", ggml_backend_buft_name(it.second)); + } + invalid_param = true; + break; + } + group_tensor_buft_overrides.push_back({tensor_name, buft_list.at(buffer_type)}); + override_span_len = std::strcspn(override_group, ";"); + } + if (invalid_param) { break; } - group_tensor_buft_overrides.push_back({tensor_name, buft_list.at(buffer_type)}); - override_span_len = std::strcspn(override_group, ";"); - } - if (invalid_param) { + group_tensor_buft_overrides.push_back({nullptr,nullptr}); + params.tensor_buft_overrides.push_back(group_tensor_buft_overrides); + override_group_span_len = std::strcspn(value, ","); + } while (!last_group); + } else if (arg == "-r" || arg == "--repetitions") { + if (++i >= argc) { + invalid_param = true; break; } - group_tensor_buft_overrides.push_back({nullptr,nullptr}); - params.tensor_buft_overrides.push_back(group_tensor_buft_overrides); - override_group_span_len = std::strcspn(value, ","); - } while (!last_group); - } else if (arg == "-r" || arg == "--repetitions") { - if (++i >= argc) { + params.reps = std::stoi(argv[i]); + } else if (arg == "--prio") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.prio = (enum ggml_sched_priority) std::stoi(argv[i]); + } else if (arg == "--delay") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.delay = std::stoi(argv[i]); + } else if (arg == "-o" || arg == "--output") { + if (++i >= argc) { + invalid_param = true; + break; + } + invalid_param = !output_format_from_str(argv[i], params.output_format); + } else if (arg == "-oe" || arg == "--output-err") { + if (++i >= argc) { 
+ invalid_param = true; + break; + } + invalid_param = !output_format_from_str(argv[i], params.output_format_stderr); + } else if (arg == "-v" || arg == "--verbose") { + params.verbose = true; + } else if (arg == "--progress") { + params.progress = true; + } else { invalid_param = true; break; } - params.reps = std::stoi(argv[i]); - } else if (arg == "--prio") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.prio = (enum ggml_sched_priority) std::stoi(argv[i]); - } else if (arg == "--delay") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.delay = std::stoi(argv[i]); - } else if (arg == "-o" || arg == "--output") { - if (++i >= argc) { - invalid_param = true; - break; - } - invalid_param = !output_format_from_str(argv[i], params.output_format); - } else if (arg == "-oe" || arg == "--output-err") { - if (++i >= argc) { - invalid_param = true; - break; - } - invalid_param = !output_format_from_str(argv[i], params.output_format_stderr); - } else if (arg == "-v" || arg == "--verbose") { - params.verbose = true; - } else if (arg == "--progress") { - params.progress = true; - } else { + } catch (const std::exception & e) { + fprintf(stderr, "error: %s\n", e.what()); invalid_param = true; break; } } + if (invalid_param) { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); print_usage(argc, argv); From 91159ee9df84edb72ccf874e353d2414f3a8c60d Mon Sep 17 00:00:00 2001 From: Anudit Nagar Date: Mon, 12 May 2025 18:56:42 +0700 Subject: [PATCH 30/33] server : allow content to be null in oaicompat_completion_params_parse (#13477) --- tools/server/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 23163f4fe..b8d140e3f 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -644,7 +644,7 @@ static json oaicompat_completion_params_parse( } for (auto & msg : messages) { json & content = msg.at("content"); - if (content.is_string()) { + if (content.is_string() || content.is_null()) { continue; } From 064cc596ac44308dc326a17c9e3163c34a6f29d1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 12 May 2025 15:12:27 +0300 Subject: [PATCH 31/33] context : fix state io for memory-less contexts (#13470) ggml-ci --- src/llama-context.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a12849f0e..0cb6ebc9f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1788,10 +1788,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } } - LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - llama_kv_cache * kv_self = static_cast(memory.get()); + if (memory) { + LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - kv_self->state_read(io); + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->state_read(io); + } return io.n_bytes(); } @@ -1799,9 +1802,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); - llama_kv_cache * kv_self = static_cast(memory.get()); + if (memory) { + llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->state_write(io, seq_id); + kv_self->state_write(io, seq_id); + } return io.n_bytes(); } @@ -1809,9 +1814,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id 
seq_id) { GGML_UNUSED(seq_id); - llama_kv_cache * kv_self = static_cast(memory.get()); + if (memory) { + llama_kv_cache * kv_self = static_cast(memory.get()); - kv_self->state_read(io, seq_id); + kv_self->state_read(io, seq_id); + } return io.n_bytes(); } From 10d2af0eaa0aafd7c6577b279dfa5221ff44a63f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 12 May 2025 14:44:49 +0200 Subject: [PATCH 32/33] llama/ggml: add LLM training support (#10544) * llama/ggml: add LLM training support more compact progress bar llama_save_model_to_file llama_opt_param_filter ggml_graph_dup force_grads refactor ggml_opt, fix test-opt * remove logits_all * refactor CUDA implementation for ACC * reset graph at beginning of opt period --- build-xcframework.sh | 1 + common/common.cpp | 17 + common/common.h | 6 + examples/CMakeLists.txt | 1 + examples/training/CMakeLists.txt | 5 + examples/training/README.md | 17 + examples/training/finetune.cpp | 96 ++++++ ggml/include/ggml-opt.h | 75 ++-- ggml/include/ggml.h | 13 +- ggml/src/ggml-backend.cpp | 2 +- ggml/src/ggml-cuda/acc.cu | 66 ++-- ggml/src/ggml-cuda/sum.cu | 2 +- ggml/src/ggml-opt.cpp | 570 ++++++++++++++++++++----------- ggml/src/ggml.c | 41 ++- include/llama.h | 36 ++ src/CMakeLists.txt | 1 + src/llama-context.cpp | 244 ++++++++++++- src/llama-context.h | 30 ++ src/llama-graph.cpp | 1 + src/llama-graph.h | 3 + src/llama-model-loader.cpp | 20 +- src/llama-model-saver.cpp | 281 +++++++++++++++ src/llama-model-saver.h | 37 ++ src/llama-model.cpp | 8 +- src/llama-model.h | 2 + src/llama-quant.cpp | 4 +- src/llama-vocab.cpp | 35 +- src/llama-vocab.h | 6 + src/llama.cpp | 9 + tests/test-backend-ops.cpp | 93 ++--- tests/test-opt.cpp | 52 +-- 31 files changed, 1415 insertions(+), 359 deletions(-) create mode 100644 examples/training/CMakeLists.txt create mode 100644 examples/training/README.md create mode 100644 examples/training/finetune.cpp create mode 100644 src/llama-model-saver.cpp create mode 100644 src/llama-model-saver.h diff --git a/build-xcframework.sh b/build-xcframework.sh index 3c2498b03..a08419a80 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -117,6 +117,7 @@ setup_framework_structure() { # Copy all required headers (common for all platforms) cp include/llama.h ${header_path} cp ggml/include/ggml.h ${header_path} + cp ggml/include/ggml-opt.h ${header_path} cp ggml/include/ggml-alloc.h ${header_path} cp ggml/include/ggml-backend.h ${header_path} cp ggml/include/ggml-metal.h ${header_path} diff --git a/common/common.cpp b/common/common.cpp index 710bf1fe2..2b1d8da59 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1565,3 +1565,20 @@ common_control_vector_data common_control_vector_load(const std::vector & tokens, int64_t stride) { + const int64_t ne_datapoint = llama_n_ctx(ctx); + const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride; + ggml_opt_dataset_t result = ggml_opt_dataset_init( + GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1); + + llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data; + llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data; + + for (int64_t idata = 0; idata < ndata; ++idata) { + memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token)); + memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token)); + } + + return result; +} diff --git a/common/common.h b/common/common.h index 
e15356b12..dea34267c 100644 --- a/common/common.h +++ b/common/common.h @@ -666,3 +666,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; } + +// +// training utils +// + +ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector & tokens, int64_t stride); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4ca9230c5..49e4d2cf8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -32,6 +32,7 @@ else() add_subdirectory(speculative) add_subdirectory(speculative-simple) add_subdirectory(gen-docs) + add_subdirectory(training) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/training/CMakeLists.txt b/examples/training/CMakeLists.txt new file mode 100644 index 000000000..64afe6ddc --- /dev/null +++ b/examples/training/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-finetune) +add_executable(${TARGET} finetune.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 000000000..ecdf398f8 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,17 @@ +# llama.cpp/examples/training + +This directory contains examples related to language model training using llama.cpp/GGML. +So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP. +Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory. +**For CPU training, compile llama.cpp without any additional backends such as CUDA.** +**For CUDA training, use the maximum number of GPU layers.** + +Proof of concept: + +``` sh +export model_name=llama_3.2-1b && export quantization=f32 +./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 +./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf +``` + +The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. 
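The dataset helper added to common.cpp above turns the tokenized training file into next-token-prediction pairs: datapoint idata spans tokens [idata*stride, idata*stride + n_ctx) and its labels are the same window shifted one token to the right, so with stride = n_ctx/2 (as finetune.cpp below passes) consecutive datapoints overlap by half a context. A minimal standalone sketch of that index math follows; the toy token count and sizes are illustrative only and do not come from the patch.

``` cpp
// Sketch only: mirrors the indexing of common_opt_dataset_init without depending on llama.cpp.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> tokens(16);                 // stand-in for the common_tokenize() output
    for (int i = 0; i < (int) tokens.size(); ++i) tokens[i] = i;

    const int n_ctx  = 4;                        // ne_datapoint == llama_n_ctx(ctx)
    const int stride = n_ctx / 2;                // finetune.cpp passes llama_n_ctx(ctx)/2
    const int ndata  = ((int) tokens.size() - n_ctx - 1) / stride;

    for (int idata = 0; idata < ndata; ++idata) {
        // data   <- tokens[idata*stride + 0 .. idata*stride + n_ctx - 1]
        // labels <- tokens[idata*stride + 1 .. idata*stride + n_ctx]
        printf("datapoint %d: data [%d..%d)  labels [%d..%d)\n",
               idata, idata*stride, idata*stride + n_ctx,
               idata*stride + 1, idata*stride + 1 + n_ctx);
    }
    return 0;
}
```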
diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp new file mode 100644 index 000000000..23bede49b --- /dev/null +++ b/examples/training/finetune.cpp @@ -0,0 +1,96 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +int main(int argc, char ** argv) { + common_params params; + + params.escape = false; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { + return 1; + } + + if (params.use_mmap) { + LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); + params.use_mmap = false; + } + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model and apply lora adapter, if any + common_init_result llama_init = common_init_from_params(params); + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + constexpr float val_split = 0.05f; + + std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + + struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); + optimizer_params.adamw.alpha = 1e-7f; // learning rate + + struct llama_opt_params lopt_params { + /*n_ctx_train =*/ 0, + /*param_filter =*/ llama_opt_param_filter_all, + /*param_filter_ud =*/ nullptr, + /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, + /*get_opt_pars_ud =*/ &optimizer_params, + }; + llama_opt_init(ctx.get(), model.get(), lopt_params); + + const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_eval = ggml_opt_result_init(); + + for (int epoch = 0; epoch < 2; ++epoch) { + llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); + fprintf(stderr, "\n"); + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_eval); + } + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_eval); + + llama_model_save_to_file(model.get(), "finetuned-model.gguf"); + + llama_backend_free(); + + return 0; +} diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index eb5eab9de..da0c24b46 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -37,13 +37,16 @@ extern "C" { // ====== Dataset ====== GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( - int64_t ne_datapoint, // number of elements per datapoint - int64_t ne_label, // number of elements per label - int64_t ndata, // total number of datapoints/labels - int64_t ndata_shard); // number of datapoints/labels per 
shard (unit at which the dataset is shuffled/copied) + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); // get underlying tensors that store the data + GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] @@ -56,13 +59,19 @@ extern "C" { struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + int64_t ibatch); // ====== Model / Context ====== enum ggml_opt_build_type { - GGML_OPT_BUILD_TYPE_FORWARD, - GGML_OPT_BUILD_TYPE_GRAD, - GGML_OPT_BUILD_TYPE_OPT, + GGML_OPT_BUILD_TYPE_FORWARD = 10, + GGML_OPT_BUILD_TYPE_GRAD = 20, + GGML_OPT_BUILD_TYPE_OPT = 30, }; // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss @@ -81,20 +90,22 @@ extern "C" { // userdata can be used to pass arbitrary data typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); - // returns the default optimizer params (constant) + // returns the default optimizer params (constant, hard-coded values) // userdata is not used GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + // casts userdata to ggml_opt_optimizer_params and returns it + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); + // parameters for initializing a new optimization context struct ggml_opt_params { ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs - struct ggml_context * ctx_compute; // created in user code, holds non-static tensors - - // the forward graph is defined by inputs and outputs - // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts - struct ggml_tensor * inputs; - struct ggml_tensor * outputs; + // by default the forward graph needs to be reconstructed for each eval + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically + struct ggml_context * ctx_compute; + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; enum ggml_opt_loss_type loss_type; enum ggml_opt_build_type build_type; @@ -107,12 +118,9 @@ extern "C" { // get parameters for an optimization context with defaults set where possible // parameters for which no sensible defaults exist are supplied as arguments to this function - GGML_API ggml_opt_params ggml_opt_default_params( - ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, - enum ggml_opt_loss_type loss_type); + GGML_API struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t 
backend_sched, + enum ggml_opt_loss_type loss_type); GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); @@ -121,6 +129,7 @@ extern "C" { GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); // get underlying tensors that store data + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against @@ -128,11 +137,12 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + // get the gradient accumulator for a node from the forward graph GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); // ====== Optimization Result ====== - GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API ggml_opt_result_t ggml_opt_result_init(void); GGML_API void ggml_opt_result_free(ggml_opt_result_t result); GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); @@ -144,11 +154,20 @@ extern "C" { // ====== Computation ====== - // do forward pass, increment result if not NULL - GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // if not using static graphs, this function must be called prior to ggml_opt_alloc + GGML_API void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs); - // do forward pass, increment result if not NULL, do backward pass - GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // allocate the next graph for evaluation, either forward or forward + backward + // must be called exactly once prior to calling ggml_opt_eval + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); + + // do forward pass, increment result if not NULL, do backward pass if allocated + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); // ############################################################################ // ## The high-level functions start here. 
They do not depend on any private ## @@ -200,9 +219,9 @@ extern "C" { // fit model defined by inputs and outputs to dataset GGML_API void ggml_opt_fit( ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs - ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs - ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] - ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used ggml_opt_dataset_t dataset, // dataset with data and optionally also labels enum ggml_opt_loss_type loss_type, // loss to minimize ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c518366d5..e91dedf14 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -768,7 +768,7 @@ extern "C" { // Tensor flags GGML_API void ggml_set_input(struct ggml_tensor * tensor); GGML_API void ggml_set_output(struct ggml_tensor * tensor); - GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); // @@ -938,7 +938,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride // concat a and b along dim // used in stable-diffusion @@ -2049,15 +2049,14 @@ extern "C" { GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( - struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) - struct ggml_context * ctx_compute, // context for gradient computation - struct ggml_cgraph * cgraph, - bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static + struct ggml_context * ctx, // context for gradient computation + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs); // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); - GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 6f69d895f..b30b4cb38 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1111,7 +1111,7 @@ static void 
ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - assert(node_backend_id != -1); // all nodes should be assigned by now + assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback // check if we should start a new split based on the sources of the current node bool need_new_split = false; diff --git a/ggml/src/ggml-cuda/acc.cu b/ggml/src/ggml-cuda/acc.cu index 96bfe1c9d..e084607c0 100644 --- a/ggml/src/ggml-cuda/acc.cu +++ b/ggml/src/ggml-cuda/acc.cu @@ -1,47 +1,61 @@ #include "acc.cuh" -static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset) { - const int i = blockDim.x * blockIdx.x + threadIdx.x; +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } -static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, const int offset, cudaStream_t stream) { - int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; - acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) { + const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); } void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = 
dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); - acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream); + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); + + acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream); } diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index f9589080a..eb3d7cdba 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 7c3e24103..58d77578f 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -28,16 +28,19 @@ struct ggml_opt_dataset { }; struct ggml_opt_context { - ggml_backend_sched_t backend_sched = nullptr; - ggml_cgraph * allocated_graph = nullptr; - ggml_cgraph * allocated_graph_copy = nullptr; - struct ggml_context * ctx_static = nullptr; - struct ggml_context * ctx_static_cpu = nullptr; - struct ggml_context * ctx_compute = nullptr; - struct ggml_context * ctx_copy = nullptr; - ggml_backend_buffer_t buf_static = nullptr; - ggml_backend_buffer_t buf_static_cpu = nullptr; - std::mt19937 rng; + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_cpu = nullptr; + std::mt19937 rng; + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + enum ggml_opt_build_type build_type_alloc; struct ggml_tensor * inputs = nullptr; struct ggml_tensor * outputs = nullptr; @@ -50,6 +53,11 @@ struct ggml_opt_context { struct ggml_cgraph * gf = nullptr; struct ggml_cgraph * gb_grad = nullptr; struct ggml_cgraph * gb_opt = nullptr; + bool static_graphs = false; + bool eval_ready = false; + std::vector grad_accs; + std::vector grad_m; + std::vector grad_v; int64_t iter = 1; int32_t opt_period = 1; @@ -73,7 +81,13 @@ struct ggml_opt_result { // ====== Dataset ====== -ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { +ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, + enum ggml_type type_label, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ndata, + int64_t ndata_shard) { GGML_ASSERT(ne_datapoint > 0); GGML_ASSERT(ne_label >= 0); GGML_ASSERT(ndata > 0); @@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, result->ctx = ggml_init(params); } - result->data = ggml_new_tensor_2d(result->ctx, 
GGML_TYPE_F32, ne_datapoint, ndata); + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; if (ne_label > 0) { - result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; } else { result->labels = nullptr; @@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { delete dataset; } +int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) { + return dataset->ndata; +} + struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { return dataset->data; } @@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT( data_batch->type == dataset->data->type); + GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type); const size_t nb_data_batch = ggml_nbytes(data_batch); GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); @@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * } } +void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } +} + // ====== Model / Context ====== struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { @@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us return result; } +struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { + return *((struct ggml_opt_optimizer_params *) userdata); +} + struct ggml_opt_params ggml_opt_default_params( ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, enum ggml_opt_loss_type loss_type) { return { /*backend_sched =*/ backend_sched, - /*ctx_compute =*/ ctx_compute, - /*inputs =*/ inputs, - /*logits =*/ outputs, + /*ctx_compute =*/ nullptr, + /*inputs =*/ nullptr, + /*logits =*/ nullptr, /*loss_type =*/ loss_type, /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, /*opt_period =*/ 1, @@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { return dst; } 
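The dataset hunks above give ggml_opt_dataset explicit data/label types, expose the datapoint count, and add a host-memory batch getter. A minimal sketch of how these entry points combine, using only the signatures shown in this patch; the sizes (512 int32 tokens per datapoint, 1000 datapoints, no labels) and all variable names are illustrative assumptions, not part of the patch:

    #include "ggml-opt.h"
    #include <cstdint>
    #include <vector>

    void example_dataset_usage(void) {
        // token ids as data, no labels; shard granularity of one datapoint
        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
            GGML_TYPE_I32,                 // type_data
            GGML_TYPE_F32,                 // type_label (unused here since ne_label == 0)
            /*ne_datapoint =*/ 512, /*ne_label =*/ 0, /*ndata =*/ 1000, /*ndata_shard =*/ 1);

        // the raw data tensor can be filled through ggml_opt_dataset_data(dataset)
        const int64_t ndata = ggml_opt_dataset_ndata(dataset); // 1000
        (void) ndata;

        // copy one batch straight into host memory instead of into a ggml tensor
        std::vector<int32_t> tokens(512);
        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), tokens.size()*sizeof(int32_t),
                                        /*labels_batch =*/ nullptr, /*ibatch =*/ 0);

        ggml_opt_dataset_free(dataset);
    }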
-static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { - GGML_ASSERT(graph); - if (opt_ctx->allocated_graph == graph) { - return; - } +static void ggml_opt_build(ggml_opt_context_t opt_ctx) { + GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); + GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); - ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && + !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - ggml_free(opt_ctx->ctx_copy); - opt_ctx->ctx_copy = ggml_init(params); - } - - opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); - - ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); - opt_ctx->allocated_graph = graph; -} - -ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { - ggml_opt_context_t result = new struct ggml_opt_context; - result->backend_sched = params.backend_sched; - result->ctx_compute = params.ctx_compute; - result->inputs = params.inputs; - result->outputs = params.outputs; - result->opt_period = params.opt_period; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; - - GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); - GGML_ASSERT(result->opt_period >= 1); - - const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || - (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); - - ggml_set_input(result->inputs); - ggml_set_output(result->outputs); - - result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. - ggml_build_forward_expand(result->gf, result->outputs); + ggml_set_input(opt_ctx->inputs); + ggml_set_output(opt_ctx->outputs); int n_param = 0; - for (int i = 0; i < result->gf->n_nodes; ++i) { - if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) { + const struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { n_param++; } + GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented"); } - { + if (!opt_ctx->ctx_static) { // The static context is used for: - // - gradients (1 tensor per param if using gradient accumulation) + // - gradients (1 per loss, 1 tensor per param if using gradient accumulation) // - optimizer momenta (2 tensors per param) - // - labels - // - loss + its gradient (up to 5 tensors) - // - pred - // - ncorrect (2 tensors). - const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); - const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + // - labels (if using static graphs) + // - loss (if using static graphs, up to 5 tensors) + // - pred (if using static graphs) + // - ncorrect (if using static graphs, 2 tensors). + constexpr size_t n_loss = 1; + const size_t tensors_per_param = (accumulate ? 1 : 0) + + (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 
2 : 0); + const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; + const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static = ggml_init(params); + opt_ctx->ctx_static = ggml_init(params); } + GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc); + { - // The static cpu context is used for: - // - optimizer parameters (1 for the entire context) + // The cpu context is allocated statically if using static graphs, dynamically otherwise. + // It is used for: + // - optimizer parameters (1 shared for all optimizer invocations) const size_t size_meta = 1 * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static_cpu = ggml_init(params); + ggml_free(opt_ctx->ctx_cpu); + opt_ctx->ctx_cpu = ggml_init(params); + + ggml_backend_buffer_free(opt_ctx->buf_cpu); + opt_ctx->buf_cpu = nullptr; } + struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute; - switch (params.loss_type) { + switch (opt_ctx->loss_type) { case GGML_OPT_LOSS_TYPE_MEAN: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean"); - result->loss_per_datapoint = true; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean"); + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_SUM: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - result->loss_per_datapoint = false; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + opt_ctx->loss_per_datapoint = false; break; } case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_cross_entropy"); - if (result->opt_period > 1) { - result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); - ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled"); } - result->loss_per_datapoint = true; + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - 
ggml_set_name(result->labels, "labels"); - result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_error"); - result->loss = ggml_sqr(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_squared_error"); - result->loss = ggml_sum(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_sum_squared_error"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean_squared_error"); - result->loss_per_datapoint = true; + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_error"); + opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_squared_error"); + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean_squared_error"); + opt_ctx->loss_per_datapoint = true; break; } } - ggml_set_output(result->loss); - ggml_set_loss(result->loss); - ggml_build_forward_expand(result->gf, result->loss); + ggml_set_output(opt_ctx->loss); + ggml_set_loss(opt_ctx->loss); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); - result->pred = ggml_argmax(result->ctx_static, result->outputs); - ggml_set_name(result->pred, "pred"); - ggml_set_output(result->pred); - ggml_build_forward_expand(result->gf, result->pred); + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->pred, "pred"); + ggml_set_output(opt_ctx->pred); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); - if (result->labels) { - result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); - ggml_set_name(result->ncorrect, "ncorrect"); - ggml_set_output(result->ncorrect); - ggml_build_forward_expand(result->gf, result->ncorrect); - } else { - result->ncorrect = nullptr; + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); } - if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - return result; + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + return; } - // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. 
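For scale, a worked instance of the size_meta formula in ggml_opt_build above, under assumed example values: static graphs, opt_period > 1 (so gradients are accumulated), build_type_alloc == GGML_OPT_BUILD_TYPE_OPT, and n_param = 4 trainable tensors:

    tensors_per_param = 1 (gradient accumulator) + 2 (AdamW momenta m, v) = 3
    tensors_const     = 9 (labels: 1, loss chain: up to 5, pred: 1, ncorrect: 2)
    size_meta         = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead()
                      = (1 + 3*4 + 9) * ggml_tensor_overhead()
                      = 22 tensor headers; metadata only, since the context is created with no_alloc = true.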
- result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); - ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + if (opt_ctx->grad_accs.empty()) { + GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD); - if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - ggml_graph_reset(result->gb_grad); - return result; - } + const int n_nodes = opt_ctx->gf->n_nodes; + opt_ctx->grad_accs.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_accs[i] = nullptr; + } + } - GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); - - // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. - result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); - - result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); - ggml_set_input(result->adamw_params); - ggml_set_name(result->adamw_params, "adamw_params"); - - for (int i = result->gf->n_nodes-1; i >= 0; --i) { - struct ggml_tensor * node = result->gb_opt->nodes[i]; - struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); - ggml_build_forward_expand(result->gb_opt, opt_step); + if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + opt_ctx->grad_m.resize(n_nodes); + opt_ctx->grad_v.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_m[i] = nullptr; + opt_ctx->grad_v[i] = nullptr; + } + } } } - result->buf_static = ggml_backend_alloc_ctx_tensors( - result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. 
+ opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); + ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); - result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_grad); + } - ggml_graph_reset(result->gb_opt); + GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); + + opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); + ggml_set_input(opt_ctx->adamw_params); + ggml_set_name(opt_ctx->adamw_params, "adamw_params"); + + for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); + + if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + struct ggml_tensor * m = opt_ctx->grad_m[i]; + struct ggml_tensor * v = opt_ctx->grad_v[i]; + struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); + + ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); + ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); + ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); + + ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); + } + } + + if (!opt_ctx->buf_static) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_opt); + } + + opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type()); +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->loss_type = params.loss_type; + result->build_type = params.build_type; + result->build_type_alloc = params.build_type; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->opt_period >= 1); + + result->static_graphs = result->ctx_compute; + + if (!result->static_graphs) { + GGML_ASSERT(!result->inputs); + GGML_ASSERT(!result->outputs); + return result; + } + + GGML_ASSERT(result->inputs); + GGML_ASSERT(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. 
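Viewed together with the ggml-opt.h changes at the start of this patch, the rebuilt context serves two flows: static graphs, where the compute context, inputs and outputs are fixed at init time (as ggml_opt_epoch uses further down), and dynamic graphs, where a fresh graph is handed to the context before every allocation (as llama_context::opt_epoch_iter does later in this patch). A rough sketch of one training step in each mode; the helper names and the surrounding objects are illustrative assumptions, not part of the patch:

    #include "ggml-opt.h"

    // one step with static graphs, mirroring the ggml_opt_epoch loop
    static void train_step_static(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset,
                                  ggml_tensor * inputs, ggml_tensor * labels,
                                  ggml_opt_result_t result, int64_t ibatch) {
        ggml_opt_alloc(opt_ctx, /*backward =*/ true);                // picks gf/gb_grad/gb_opt and allocates it
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); // fill the statically allocated inputs
        ggml_opt_eval(opt_ctx, result);                              // forward, plus backward/optimizer if allocated
    }

    // one step with dynamic graphs, mirroring llama_context::opt_epoch_iter
    static void train_step_dynamic(ggml_opt_context_t opt_ctx, ggml_context * ctx_compute,
                                   ggml_cgraph * gf, ggml_tensor * inputs, ggml_tensor * outputs,
                                   ggml_opt_result_t result) {
        ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs); // required before ggml_opt_alloc
        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
        ggml_opt_eval(opt_ctx, result);                                    // graph pointers are invalidated afterwards
    }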
+ ggml_build_forward_expand(result->gf, result->outputs); + + ggml_opt_build(result); return result; } @@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { return; } ggml_backend_buffer_free(opt_ctx->buf_static); - ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); - ggml_free(opt_ctx->ctx_static_cpu); + ggml_free(opt_ctx->ctx_cpu); delete opt_ctx; } @@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl // ====== Computation ====== -static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { - if (graph != opt_ctx->gf) { +void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs) { + GGML_ASSERT(!opt_ctx->static_graphs); + opt_ctx->ctx_compute = ctx_compute; + opt_ctx->gf = gf; + opt_ctx->inputs = inputs; + opt_ctx->outputs = outputs; +} + +void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { + GGML_ASSERT(!opt_ctx->eval_ready); + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) { + ggml_graph_reset(opt_ctx->gb_grad); + } + if (backward) { + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD; + } else { + opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD; + } + + if (!opt_ctx->static_graphs) { + ggml_opt_build(opt_ctx); + } + + struct ggml_cgraph * graph = nullptr; + switch (opt_ctx->build_type) { + case GGML_OPT_BUILD_TYPE_FORWARD: { + graph = opt_ctx->gf; + } break; + case GGML_OPT_BUILD_TYPE_GRAD: { + graph = opt_ctx->gb_grad; + } break; + case GGML_OPT_BUILD_TYPE_OPT: { + graph = opt_ctx->gb_opt; + } break; + } + GGML_ASSERT(graph); + + if (opt_ctx->allocated_graph == graph) { + opt_ctx->eval_ready = true; + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + if (opt_ctx->static_graphs) { + ggml_init_params params = { + /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + } else { + opt_ctx->allocated_graph_copy = graph; + } + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; + + opt_ctx->eval_ready = true; +} + +void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { + GGML_ASSERT(opt_ctx->eval_ready); + if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); @@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, adamw_par_data[6] = beta2h; } - ggml_opt_alloc_graph(opt_ctx, graph); ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + + if (!opt_ctx->static_graphs) { + opt_ctx->gf = nullptr; + opt_ctx->gb_grad = nullptr; + opt_ctx->gb_opt = nullptr; + 
opt_ctx->allocated_graph = nullptr; + opt_ctx->allocated_graph_copy = nullptr; + } + + opt_ctx->eval_ready = false; if (!result) { return; @@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); result->loss.push_back(loss); - GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); - std::vector pred(ndata); - ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); - result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + if (opt_ctx->pred) { + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + } - if (!opt_ctx->labels || result->ncorrect < 0) { + if (!opt_ctx->ncorrect || result->ncorrect < 0) { result->ncorrect = -1; return; } @@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, result->ncorrect += ncorrect; } -void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); -} - -void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - if (opt_ctx->opt_period == 1) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - return; - } - - const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; - if (opt_i_next == 0) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - ggml_opt_reset(opt_ctx, /*optimizer =*/ false); - } else { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); - } - opt_ctx->opt_i = opt_i_next; -} - // ====== High-Level Functions ====== void ggml_opt_epoch( @@ -700,16 +860,18 @@ void ggml_opt_epoch( int64_t ibatch = 0; int64_t t_loop_start = ggml_time_us(); for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ true); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward_backward(opt_ctx, result_train); + ggml_opt_eval(opt_ctx, result_train); if (callback_train) { callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); } } t_loop_start = ggml_time_us(); for (; ibatch < nbatches; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ false); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward(opt_ctx, result_eval); + ggml_opt_eval(opt_ctx, result_eval); if (callback_eval) { callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); } @@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar( int64_t t_start_us) { fprintf(stderr, "%s[", train ? "train: " : "val: "); - constexpr int64_t bar_length = 25; + // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels. 
+ constexpr int64_t bar_length = 8; + const int64_t ibatch8 = 8 * ibatch; for (int64_t j = 0; j < bar_length; ++j) { - const int64_t ibatch_j = ibatch_max * j/bar_length; - if (ibatch_j < ibatch) { - fprintf(stderr, "="); - } else if (ibatch_max * (j - 1)/bar_length < ibatch) { - fprintf(stderr, ">"); + if (ibatch_max * (8*j + 8) / bar_length < ibatch8) { + fprintf(stderr, "\u2588"); // full block + } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) { + fprintf(stderr, "\u2589"); // 7/8 filled + } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) { + fprintf(stderr, "\u258A"); // 6/8 filled + } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) { + fprintf(stderr, "\u258B"); // 5/8 filled + } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) { + fprintf(stderr, "\u258C"); // 4/8 filled + } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) { + fprintf(stderr, "\u258D"); // 3/8 filled + } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) { + fprintf(stderr, "\u258E"); // 2/8 filled + } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) { + fprintf(stderr, "\u258F"); // 1/8 filled } else { fprintf(stderr, " "); } @@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar( const int64_t t_eta_m = t_eta_s / 60; t_eta_s -= t_eta_m * 60; - fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " - "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r", idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); if (ibatch == ibatch_max) { @@ -806,7 +981,10 @@ void ggml_opt_fit( int64_t epoch = 1; - ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type); + params.ctx_compute = ctx_compute; + params.inputs = inputs; + params.outputs = outputs; params.opt_period = opt_period; params.get_opt_pars = get_opt_pars; params.get_opt_pars_ud = &epoch; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bc673292b..8a6546240 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5499,7 +5499,7 @@ static void ggml_compute_backward( // tensor = src0 * 1 + src1 * 0 if (src0_needs_grads) { // dsrc0 = dtensor * 1 - ggml_add_or_set(ctx, cgraph, isrc0, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0)); } if (src1_needs_grads) { // dsrc1 = dtensor * 0 -> noop @@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * } void ggml_build_backward_expand( - struct ggml_context * ctx_static, - struct ggml_context * ctx_compute, - struct ggml_cgraph * cgraph, - bool accumulate) { + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs) { GGML_ASSERT(cgraph->n_nodes > 0); GGML_ASSERT(cgraph->grads); GGML_ASSERT(cgraph->grad_accs); @@ -5856,21 +5855,24 @@ void ggml_build_backward_expand( GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); - GGML_ASSERT(igrad != GGML_HASHSET_FULL); - 
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); - if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { - cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node); - cgraph->grads[igrad] = cgraph->grad_accs[igrad]; - ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); + const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(ihash != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash)); + if (grad_accs && grad_accs[i]) { + cgraph->grad_accs[ihash] = grad_accs[i]; + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } else if (node->flags & GGML_TENSOR_FLAG_LOSS) { + // loss tensors always need a gradient accumulator + cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; } - grads_needed[igrad] = true; + grads_needed[ihash] = true; } for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); + ggml_compute_backward(ctx, cgraph, i, grads_needed); } free(grads_needed); @@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { } } -struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads); ggml_graph_cpy(cgraph, result); return result; } @@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { } void ggml_graph_reset(struct ggml_cgraph * cgraph) { + if (!cgraph) { + return; + } GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } -void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { - GGML_UNUSED(ctx); // TODO: remove this parameter +void ggml_set_param(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_NONE); tensor->flags |= GGML_TENSOR_FLAG_PARAM; } diff --git a/include/llama.h b/include/llama.h index 6c6d377f8..abedebdb7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -4,6 +4,7 @@ #include "ggml.h" #include "ggml-cpu.h" #include "ggml-backend.h" +#include "ggml-opt.h" #include #include @@ -445,6 +446,10 @@ extern "C" { size_t n_paths, struct llama_model_params params); + LLAMA_API void llama_model_save_to_file( + const struct llama_model * model, + const char * path_model); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), "use llama_model_free instead"); @@ -1433,6 +1438,37 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + // + // training + // + + // function that returns whether or not a given tensor contains trainable parameters + typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); + + // always returns true + LLAMA_API bool 
llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); + + struct llama_opt_params { + uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 + + llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters + void * param_filter_ud; // userdata for determining which tensors contain trainable parameters + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); + + LLAMA_API void llama_opt_epoch( + struct llama_context * lctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1cd316b03..d4bf37b1c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp + llama-model-saver.cpp llama-model.cpp llama-quant.cpp llama-sampling.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 0cb6ebc9f..62246c10d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -359,7 +359,9 @@ llama_context::llama_context( } } -llama_context::~llama_context() = default; +llama_context::~llama_context() { + ggml_opt_free(opt_ctx); +} void llama_context::synchronize() { ggml_backend_sched_synchronize(sched.get()); @@ -1846,6 +1848,215 @@ void llama_context::perf_reset() { t_p_eval_us = n_p_eval = 0; } +// +// training +// + +static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { + if (!tensor || tensor->type != GGML_TYPE_F32) { + return; + } + if (!param_filter(tensor, userdata)) { + return; + } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + return; // FIXME + } + if (strcmp(tensor->name, "rope_freqs.weight") == 0) { + return; // FIXME + } + ggml_set_param(tensor); +} + +void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { + GGML_ASSERT(!opt_ctx); + model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? 
lopt_params.n_ctx_train : n_ctx(); + const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); + GGML_ASSERT(n_batch % n_ubatch == 0); + + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + opt_params.opt_period = n_batch / n_ubatch; + opt_params.get_opt_pars = lopt_params.get_opt_pars; + opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; + + opt_ctx = ggml_opt_init(opt_params); + + llama_opt_param_filter param_filter = lopt_params.param_filter; + void * param_filter_ud = lopt_params.param_filter_ud; + + //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME + llama_set_param(model->type_embd, param_filter, param_filter_ud); + llama_set_param(model->pos_embd, param_filter, param_filter_ud); + llama_set_param(model->tok_norm, param_filter, param_filter_ud); + llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm, param_filter, param_filter_ud); + llama_set_param(model->output_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output, param_filter, param_filter_ud); + llama_set_param(model->output_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); + llama_set_param(model->cls, param_filter, param_filter_ud); + llama_set_param(model->cls_b, param_filter, param_filter_ud); + llama_set_param(model->cls_out, param_filter, param_filter_ud); + llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + + for (struct llama_layer & layer : model->layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); + } + } +} + +void llama_context::opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start) { + GGML_ASSERT(opt_ctx); + const uint32_t n_ctx = llama_model_n_ctx_train(&model); + const uint32_t n_batch = std::min(this->n_batch(), n_ctx); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->clear(); + llama_kv_cache_guard kv_guard(kv_self); + + for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { + batch.n_tokens = n_batch; + for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { + batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; + batch.pos [pos_batch] = pos_ctx + pos_batch; + batch.n_seq_id[pos_batch] = 1; + batch.seq_id [pos_batch][0] = 0; + batch.logits [pos_batch] = true; + } + + const auto n_tokens_all = batch.n_tokens; + + n_queued_tokens += n_tokens_all; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + embd_seq.clear(); + + int64_t n_outputs_all = n_tokens_all; + + llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true); + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); + GGML_ABORT("TODO: 
handle this error"); + }; + + for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) { + llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); + + n_outputs = ubatch.n_tokens; + + // TODO: not sure if this is needed + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); + + GGML_ABORT("TODO: handle this error"); + } + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); + + struct ggml_context * ctx_compute_opt; + { + const size_t size_gf = ggml_graph_size(gf); + const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute_opt = ggml_init(params); + } + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_alloc(opt_ctx, train); + res->set_inputs(&ubatch); + { + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + GGML_ASSERT(labels->ne[1] == n_ubatch); + ggml_set_zero(labels); + const float onef = 1.0f; + for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { + const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; + GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); + ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); + } + } + ggml_opt_eval(opt_ctx, result); + if (callback) { + callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); + } + ggml_free(ctx_compute_opt); + } + } + + kv_guard.commit(); +} + +void llama_context::opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + const uint32_t n_ctx = this->n_ctx(); + const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); + const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); + const int64_t ndata = ggml_opt_dataset_ndata(dataset); + + GGML_ASSERT(idata_split >= 0); + GGML_ASSERT(idata_split <= ndata); + + const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; + + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + std::vector tokens(n_ctx); + std::vector labels_sparse(n_ctx); + + int64_t idata = 0; + + int64_t t_loop_start = ggml_time_us(); + int64_t ndata_in_loop = idata_split*ubatch_per_ctx; + for (; idata < idata_split; ++idata) { + constexpr bool train = true; + const int64_t idata_in_loop = idata*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, + callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + t_loop_start = ggml_time_us(); + ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; + for (; idata < ndata; ++idata) { + constexpr bool train = false; + const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, + callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + 
llama_batch_free(batch); +} + // // interface implementation // @@ -2464,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) { void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } + +// +// training +// + +bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { + GGML_UNUSED(tensor); + GGML_UNUSED(userdata); + return true; +} + +void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { + ctx->opt_init(model, lopt_params); +} + +void llama_opt_epoch( + struct llama_context * ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + ctx->opt_epoch( + dataset, + result_train, + result_eval, + idata_split, + callback_train, + callback_eval); +} diff --git a/src/llama-context.h b/src/llama-context.h index 5a080e67f..c0ceacb10 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -7,6 +7,7 @@ #include "llama-adapter.h" #include "ggml-cpp.h" +#include "ggml-opt.h" #include #include @@ -133,6 +134,32 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); + // + // training + // + + void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + + void opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + void opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start); + private: // // output @@ -212,6 +239,9 @@ private: ggml_context_ptr ctx_compute; + // training + ggml_opt_context_t opt_ctx = nullptr; + ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool_batch = nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a8bb83cc5..b0e3f6359 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -971,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); //cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5b404366d..832a8c09f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -298,6 +298,7 @@ class llm_graph_result_i { public: virtual ~llm_graph_result_i() = default; + virtual ggml_tensor * get_tokens() = 0; virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_embd() = 0; virtual ggml_tensor * get_embd_pooled() = 0; @@ -312,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i { public: virtual ~llm_graph_result() = default; + ggml_tensor * get_tokens() override { return t_tokens; } ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_embd() override { return t_embd; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } @@ -328,6 +330,7 @@ public: } // important graph nodes + ggml_tensor * t_tokens = nullptr; ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; diff 
--git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 1c8bce385..4cce51668 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -301,12 +301,12 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); } result.resize(arr_info.length); @@ -330,12 +330,12 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(meta.get(), kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); } if (arr_info.length > N_MAX) { diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp new file mode 100644 index 000000000..a70b98923 --- /dev/null +++ b/src/llama-model-saver.cpp @@ -0,0 +1,281 @@ +#include "llama-model-saver.h" + +#include "gguf.h" + +#include "llama.h" +#include "llama-hparams.h" +#include "llama-model.h" +#include "llama-vocab.h" + +#include + +llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { + gguf_ctx = gguf_init_empty(); +} + +llama_model_saver::~llama_model_saver() { + gguf_free(gguf_ctx); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { + gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) { + gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const float value) { + gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const bool value) { + gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const char * value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value); +} + +[[noreturn]] +void llama_model_saver::add_kv(const enum llm_kv key, const char value) { + GGML_UNUSED(key); + GGML_UNUSED(value); + GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile +} + +template +void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { + const size_t n_values = per_layer ? 
size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(n_values <= value.size()); + + if (n_values == 0) { + return; + } + + if (per_layer) { + bool all_values_the_same = true; + for (size_t i = 1; i < n_values; ++i) { + if (value[i] != value[0]) { + all_values_the_same = false; + break; + } + } + if (all_values_the_same) { + add_kv(key, value[0]); + return; + } + } + + if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast(value.data())); + } else { + GGML_ABORT("fatal error"); + } +} + +void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { + std::vector tmp(value.size()); + for (size_t i = 0; i < value.size(); ++i) { + tmp[i] = value[i].c_str(); + } + gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size()); +} + +void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { + if (!tensor) { + return; + } + if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { + GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + return; + } + gguf_add_tensor(gguf_ctx, tensor); +} + +void llama_model_saver::add_kv_from_model() { + const llama_hparams & hparams = model.hparams; + const llama_vocab & vocab = model.vocab; + + const int32_t n_vocab = vocab.n_tokens(); + std::vector tokens(n_vocab); + std::vector scores(n_vocab); + std::vector token_types(n_vocab); + + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } + } + + // add_kv(LLM_KV_GENERAL_TYPE, ???); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); + // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + add_kv(LLM_KV_GENERAL_NAME, model.name); + // add_kv(LLM_KV_GENERAL_AUTHOR, ???); + // add_kv(LLM_KV_GENERAL_VERSION, ???); + // add_kv(LLM_KV_GENERAL_URL, ???); + // add_kv(LLM_KV_GENERAL_DESCRIPTION, ???); + // add_kv(LLM_KV_GENERAL_LICENSE, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_URL, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???); + + add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); + add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + add_kv(LLM_KV_EMBEDDING_LENGTH, 
hparams.n_embd); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); + add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); + add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); + add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); + add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); + add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); + add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + + add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); + add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); + add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 
0.0f : 1.0f/hparams.rope_freq_scale_train; + + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name + add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); + add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); + add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); + add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); + add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + // TODO: implement split file support + // add_kv(LLM_KV_SPLIT_NO, ???); + // add_kv(LLM_KV_SPLIT_COUNT, ???); + // add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???); + + add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + + add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); + add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre()); + add_kv(LLM_KV_TOKENIZER_LIST, tokens); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types()); + add_kv(LLM_KV_TOKENIZER_SCORES, scores); + add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges()); + // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though + add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos())); + add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos())); + add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot())); + add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom())); + add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk())); + add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep())); + add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad())); + // add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated + // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???); + add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos()); + add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos()); + add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix()); + add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces()); + add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap()); + // add_kv(LLM_KV_TOKENIZER_HF_JSON, ???); + // add_kv(LLM_KV_TOKENIZER_RWKV, ???); + add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre())); + add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf())); + add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid())); + add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad())); + add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep())); + add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep())); + + // TODO: implement LoRA support + // add_kv(LLM_KV_ADAPTER_TYPE, ???); + // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + + // deprecated + // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); +} + +void llama_model_saver::add_tensors_from_model() { + if (std::string(model.output->name) != std::string(model.tok_embd->name)) { + add_tensor(model.tok_embd); // some 
models use the same tensor for tok_embd and output + } + add_tensor(model.type_embd); + add_tensor(model.pos_embd); + add_tensor(model.tok_norm); + add_tensor(model.tok_norm_b); + add_tensor(model.output_norm); + add_tensor(model.output_norm_b); + add_tensor(model.output); + add_tensor(model.output_b); + add_tensor(model.output_norm_enc); + add_tensor(model.cls); + add_tensor(model.cls_b); + add_tensor(model.cls_out); + add_tensor(model.cls_out_b); + + for (const struct llama_layer & layer : model.layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + add_tensor(reinterpret_cast(&layer)[i]); + } + } +} + +void llama_model_saver::save(const std::string & path_model) { + gguf_write_to_file(gguf_ctx, path_model.c_str(), false); +} + diff --git a/src/llama-model-saver.h b/src/llama-model-saver.h new file mode 100644 index 000000000..a5a434c30 --- /dev/null +++ b/src/llama-model-saver.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" + +#include + +struct llama_model_saver { + struct gguf_context * gguf_ctx = nullptr; + const struct llama_model & model; + const struct LLM_KV llm_kv; + + llama_model_saver(const struct llama_model & model); + ~llama_model_saver(); + + void add_kv(enum llm_kv key, uint32_t value); + void add_kv(enum llm_kv key, int32_t value); + void add_kv(enum llm_kv key, float value); + void add_kv(enum llm_kv key, bool value); + void add_kv(enum llm_kv key, const char * value); + + [[noreturn]] + void add_kv(enum llm_kv key, char value); // needed to make the template below compile + + template + void add_kv(enum llm_kv key, const Container & value, bool per_layer = false); + + void add_kv(enum llm_kv key, const std::vector & value); + + void add_tensor(const struct ggml_tensor * tensor); + + void add_kv_from_model(); + + void add_tensors_from_model(); + + void save(const std::string & path_model); +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 21b12339a..3a4e72a36 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -117,6 +117,10 @@ static const std::map LLAMA_ROPE_SCALING_ { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) { + return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type); +} + static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { if (kv.second == name) { @@ -4264,7 +4268,7 @@ uint64_t llama_model::n_elements() const { } void llama_model::print_info() const { - const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); auto print_f = [](const std::function & f, uint32_t n) { bool is_var = false; @@ -4325,7 +4329,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, 
hparams.n_ctx_orig_yarn); diff --git a/src/llama-model.h b/src/llama-model.h index 815fa11eb..6bdec263b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -96,6 +96,8 @@ enum llm_type { LLM_TYPE_235B_A22B, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 7dc542276..820d5128e 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - // mmap consistently increases speed Linux, and also increases speed on Windows with + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) constexpr bool use_mmap = true; @@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_kv_override * kv_overrides = nullptr; if (params->kv_overrides) { - auto v = (std::vector*)params->kv_overrides; + auto * v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e6b96ecc8..9389ca805 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1,5 +1,7 @@ #include "llama-vocab.h" +#include "ggml.h" +#include "gguf.h" #include "llama-impl.h" #include "llama-model-loader.h" @@ -1234,6 +1236,9 @@ struct fragment_buffer_variant { struct llama_vocab::impl { uint32_t n_token_types = 0; // for BERT-style token types + std::string tokenizer_model; + std::string tokenizer_pre; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -1369,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // determine vocab type { - std::string tokenizer_model; - std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); @@ -1466,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); + GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8); + + const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #ifdef IS_BIG_ENDIAN @@ -2789,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); } +std::string llama_vocab::get_tokenizer_model() const { + return pimpl->tokenizer_model; +} + +std::string llama_vocab::get_tokenizer_pre() const { + return pimpl->tokenizer_pre; +} + enum llama_vocab_type llama_vocab::get_type() const { return pimpl->type; } @@ -3011,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string return it->second; } +std::vector llama_vocab::get_bpe_merges() const { + std::vector result(pimpl->bpe_ranks.size()); + + for 
(const auto & pair : pimpl->bpe_ranks) { + result[pair.second] = pair.first.first + " " + pair.first.second; + } + + return result; +} + +std::vector llama_vocab::get_precompiled_charsmap() const { + return pimpl->precompiled_charsmap; +} + int32_t llama_vocab::tokenize( const char * text, int32_t text_len, diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 5ce355214..daa6cf308 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -21,6 +21,9 @@ struct llama_vocab { void load(llama_model_loader & ml, const LLM_KV & kv); + std::string get_tokenizer_model() const; + std::string get_tokenizer_pre() const; + enum llama_vocab_type get_type() const; enum llama_vocab_pre_type get_pre_type() const; @@ -80,6 +83,9 @@ struct llama_vocab { int max_token_len() const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; + std::vector get_bpe_merges() const; + + std::vector get_precompiled_charsmap() const; int32_t tokenize( const char * text, diff --git a/src/llama.cpp b/src/llama.cpp index d5164720b..9fdddf7b0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4,6 +4,7 @@ #include "llama-mmap.h" #include "llama-vocab.h" #include "llama-model-loader.h" +#include "llama-model-saver.h" #include "llama-model.h" #include "ggml.h" @@ -253,6 +254,13 @@ struct llama_model * llama_model_load_from_splits( return llama_model_load_from_file_impl(splits.front(), splits, params); } +void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { + llama_model_saver ms(*model); + ms.add_kv_from_model(); + ms.add_tensors_from_model(); + ms.save(path_model); +} + // // chat templates // @@ -338,3 +346,4 @@ const char * llama_print_system_info(void) { return s.c_str(); } + diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9ec24d9f2..543db9340 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -823,7 +823,7 @@ struct test_case { ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false); + ggml_build_backward_expand(ctx.get(), gb, nullptr); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { @@ -1026,7 +1026,7 @@ struct test_example : public test_case { // Step 3: return the output tensor. return out; } - // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a) + // In order to also check the gradients for your op, add calls like ggml_set_param(a) // immediately after you create the tensors. // This is optional and only makes sense if a backward pass has actually been implemented for the new op. 
}; @@ -1058,7 +1058,7 @@ struct test_unary : public test_case { auto ne = ne_a; ne[0] *= 3; a = ggml_new_tensor(ctx, type, 4, ne.data()); if (grad_supported) { - ggml_set_param(ctx, a); + ggml_set_param(a); } ggml_set_name(a, "a"); @@ -1067,7 +1067,7 @@ struct test_unary : public test_case { } else { a = ggml_new_tensor(ctx, type, 4, ne_a.data()); if (grad_supported) { - ggml_set_param(ctx, a); + ggml_set_param(a); } ggml_set_name(a, "a"); } @@ -1133,7 +1133,7 @@ struct test_get_rows : public test_case { const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows); if (grad_supported) { - ggml_set_param(ctx, in); + ggml_set_param(in); // rows is a constant input -> no gradients } @@ -1322,7 +1322,7 @@ struct test_repeat : public test_case { ggml_set_name(target, "target"); ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, src); + ggml_set_param(src); ggml_set_name(src, "src"); ggml_tensor * out = ggml_repeat(ctx, src, target); @@ -1406,7 +1406,7 @@ struct test_dup : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, src); + ggml_set_param(src); ggml_set_name(src, "src"); if (_use_permute) { @@ -1442,7 +1442,7 @@ struct test_set : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); - ggml_set_param(ctx, src); + ggml_set_param(src); ggml_set_name(src, "src"); auto ne_dst = ne; @@ -1450,7 +1450,7 @@ struct test_set : public test_case { ne_dst[i] *= 2; } ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); - ggml_set_param(ctx, dst); + ggml_set_param(dst); ggml_set_name(dst, "dst"); size_t offset = 0; @@ -1498,7 +1498,7 @@ struct test_cpy : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); - ggml_set_param(ctx, src); + ggml_set_param(src); ggml_set_name(src, "src"); if (_src_use_permute) { @@ -1536,7 +1536,7 @@ struct test_cont : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, src); + ggml_set_param(src); ggml_set_name(src, "src"); src = ggml_transpose(ctx, src); @@ -1583,8 +1583,8 @@ struct test_bin_bcast : public test_case { // The backward pass supports broadcasting only for GGML_ADD: const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b); if (grad_supported) { - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + ggml_set_param(a); + ggml_set_param(b); } ggml_tensor * out = op(ctx, a, b); @@ -1632,11 +1632,11 @@ struct test_add1 : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1); - // ggml_set_param(ctx, b); // TODO: implement + // ggml_set_param(b); // TODO: implement ggml_set_name(b, "b"); ggml_tensor * out = ggml_add1(ctx, a, b); @@ -1667,7 +1667,7 @@ struct test_scale : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_scale(ctx, a, scale); @@ -1762,7 +1762,7 @@ struct test_rms_norm : public test_case { ggml_tensor * 
build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); if (v) { @@ -2028,9 +2028,9 @@ struct test_mul_mat : public test_case { b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]); if (!ggml_is_quantized(type_a)) { if (bs[1] == 1 && nr[1] == 1) { - ggml_set_param(ctx, a); + ggml_set_param(a); } - ggml_set_param(ctx, b); + ggml_set_param(b); } ggml_set_name(a, "a"); ggml_set_name(b, "b"); @@ -2040,22 +2040,29 @@ struct test_mul_mat : public test_case { ggml_set_name(a, "a_permuted"); ggml_set_name(b, "b_permuted"); } else { - if (v) { a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]); + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(a); + } + ggml_set_param(b); + } + a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0); b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0); } else { a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - } - if (!ggml_is_quantized(type_a)) { - if (bs[1] == 1 && nr[1] == 1) { - ggml_set_param(ctx, a); + + if (!ggml_is_quantized(type_a)) { + if (bs[1] == 1 && nr[1] == 1) { + ggml_set_param(a); + } + ggml_set_param(b); } - ggml_set_param(ctx, b); } ggml_set_name(a, "a"); ggml_set_name(b, "b"); @@ -2204,7 +2211,7 @@ struct test_sqr : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_sqr(ctx, a); @@ -2233,7 +2240,7 @@ struct test_sqrt : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_sqrt(ctx, a); @@ -2273,7 +2280,7 @@ struct test_log : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_log(ctx, a); @@ -2309,7 +2316,7 @@ struct test_sin : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_sin(ctx, a); @@ -2352,7 +2359,7 @@ struct test_cos : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_cos(ctx, a); @@ -2432,7 +2439,7 @@ struct test_diag_mask_inf : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past); @@ -2471,7 +2478,7 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * mask = 
nullptr; @@ -2553,7 +2560,7 @@ struct test_rope : public test_case { auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3; a = ggml_new_tensor(ctx, type, 4, ne.data()); if (forward) { - ggml_set_param(ctx, a); + ggml_set_param(a); } ggml_set_name(a, "a"); @@ -2562,7 +2569,7 @@ struct test_rope : public test_case { } else { a = ggml_new_tensor(ctx, type, 4, ne_a.data()); if (forward) { - ggml_set_param(ctx, a); + ggml_set_param(a); } ggml_set_name(a, "a"); } @@ -2676,7 +2683,7 @@ struct test_pool2d : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); - ggml_set_param(ctx, input); + ggml_set_param(input); ggml_set_name(input, "input"); ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1); @@ -2752,7 +2759,7 @@ struct test_im2col : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data()); - ggml_set_param(ctx, input); + ggml_set_param(input); ggml_set_name(input, "input"); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); @@ -2929,7 +2936,7 @@ struct test_sum : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_sum(ctx, a); @@ -2958,7 +2965,7 @@ struct test_sum_rows : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_sum_rows(ctx, a); @@ -2983,7 +2990,7 @@ struct test_mean : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * out = ggml_mean(ctx, a); @@ -3129,11 +3136,11 @@ struct test_acc : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); - ggml_set_param(ctx, a); + ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); - ggml_set_param(ctx, b); + ggml_set_param(b); ggml_set_name(b, "b"); ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]); @@ -3370,7 +3377,7 @@ struct test_cross_entropy_loss : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_set_param(ctx, logits); + ggml_set_param(logits); ggml_set_name(logits, "logits"); ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data()); @@ -3452,7 +3459,7 @@ struct test_opt_step_adamw : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); - ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not. + ggml_set_param(a); // Despite tensor a having gradients the output tensor will not. 
ggml_set_name(a, "a"); ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 1bc160511..558f87721 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -57,7 +57,8 @@ static helper_ctx_data helper_get_ctx_data( enum ggml_opt_loss_type loss_type = GGML_OPT_LOSS_TYPE_SUM) { std::vector datasets(ndata); for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) { - ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init( + GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard); float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); @@ -74,7 +75,8 @@ static helper_ctx_data helper_get_ctx_data( datasets[ndata_shard-1] = dataset; } - ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1); + ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init( + GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1); float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised)); @@ -113,7 +115,7 @@ static helper_ctx_data helper_get_ctx_data( struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); ggml_set_name(weights, "weights"); - ggml_set_param(ctx_static, weights); + ggml_set_param(weights); struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights); @@ -127,8 +129,11 @@ static helper_ctx_data helper_get_ctx_data( GGML_ASSERT(nbatch_logical % nbatch_physical == 0); const int32_t opt_period = nbatch_logical / nbatch_physical; - struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); - opt_params.opt_period = opt_period; + struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type); + opt_params.ctx_compute = ctx_compute; + opt_params.inputs = inputs; + opt_params.outputs = outputs; + opt_params.opt_period = opt_period; if (!optimizer_defaults) { opt_params.get_opt_pars = helper_get_test_opt_pars; } @@ -264,8 +269,9 @@ static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_ba for (int idata = 0; idata < ndata; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); - ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float)); } @@ -334,8 +340,9 @@ static std::pair test_forward_backward( } else { for (int idata = 0; idata < ndata; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false); ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); - ggml_opt_forward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); } } @@ -367,7 +374,8 @@ static std::pair test_forward_backward( float w0; ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float)); for (int i = 0; i < 10; ++i) { - ggml_opt_forward_backward(cd.opt_ctx, nullptr); + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); + ggml_opt_eval(cd.opt_ctx, cd.result); } ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float)); @@ -387,8 +395,9 @@ static std::pair 
test_forward_backward( } else { for (int idata = 0; idata < ndata; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); - ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); } } @@ -492,14 +501,16 @@ static std::pair test_idata_split(ggml_backend_sched_t backend_sched, int idata = 0; for (; idata < idata_split; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); - ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); } for (; idata < ndata; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false); ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); - ggml_opt_forward(cd.opt_ctx, cd.result2); + ggml_opt_eval(cd.opt_ctx, cd.result2); ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); } } @@ -573,7 +584,6 @@ static std::pair test_gradient_accumulation( struct helper_ctx_data cd = helper_get_ctx_data( backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type); - struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); std::vector grad_history(ndata); for (int64_t idata = 0; idata < ndata; ++idata) { @@ -584,15 +594,17 @@ static std::pair test_gradient_accumulation( if (nbatch_physical == 1) { for (int idata = 0; idata < ndata; ++idata) { const float idataf = idata; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float)); - ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float)); } } else if (nbatch_physical == 2) { for (int idata = 0; idata < ndata; idata += 2) { const float idataf[2] = {float(idata + 0), float(idata + 1)}; + ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true); ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float)); - ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_opt_eval(cd.opt_ctx, cd.result); grad_history[idata + 0] = 0.0f; ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float)); @@ -617,7 +629,7 @@ static std::pair test_gradient_accumulation( } subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol); subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol); - subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol); } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { if (nbatch_physical == 1) { subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol); @@ -630,7 +642,7 @@ static std::pair test_gradient_accumulation( } subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol); subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol); - subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol); } else { GGML_ASSERT(false); } @@ -692,7 +704,8 @@ static std::pair 
test_regression(ggml_backend_sched_t backend_sched, g std::mt19937 gen(12345); std::normal_distribution<float> nd{0.0f, 0.1f}; - ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init( + GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression); float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); @@ -733,15 +746,14 @@ static std::pair test_regression(ggml_backend_sched_t backend_sched, g struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); ggml_set_name(a, "a"); - ggml_set_param(ctx_static, a); + ggml_set_param(a); struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); ggml_set_name(b, "b"); - ggml_set_param(ctx_static, b); + ggml_set_param(b); struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b); ggml_set_name(f, "f"); - ggml_set_param(ctx_static, f); ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend); const float a0 = 1.0f; From de4c07f93783a1a96456a44dc16b9db538ee1618 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 12 May 2025 15:06:51 +0200 Subject: [PATCH 33/33] clip : cap max image size 1024 for qwen vl model (#13478) --- tools/mtmd/clip.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 0adf03163..41ba45a79 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1909,16 +1909,20 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_QWEN2VL: { - // max image size = sqrt(max_pixels) - // https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json - hparams.image_size = 3584; + // max image size = sqrt(max_pixels) = 3584 + // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json + // however, the model uses an unreasonable amount of memory past an image size of 1024, so we cap it at 1024; otherwise it is unusable + // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10 + hparams.image_size = 1024; hparams.warmup_image_size = hparams.patch_size * 8; } break; case PROJECTOR_TYPE_QWEN25VL: { // max image size = sqrt(max_pixels) // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json - hparams.image_size = 3584; + // however, the model uses an unreasonable amount of memory past an image size of 1024, so we cap it at 1024; otherwise it is unusable + // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10 + hparams.image_size = 1024; hparams.warmup_image_size = hparams.patch_size * 8; get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); } break;
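
The 1024 cap above is a pragmatic trade-off: the number of vision tokens grows quadratically with the image edge. As a rough sketch, assuming the 14-pixel patch size these Qwen-VL encoders use, a 1024 x 1024 input yields about (1024/14)^2 ≈ 5.3k patches, whereas the advertised 3584 x 3584 maximum would yield about (3584/14)^2 ≈ 65.5k patches, i.e. over 12x as many tokens and a far larger attention footprint, which is what made the original limit unusable in practice.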
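
Stepping back to the llama-model-saver hunks earlier in this series: the new class is exposed through a single public entry point, llama_model_save_to_file(), added in src/llama.cpp. A minimal round-trip sketch is shown below; the file paths are placeholders and the use_mmap setting is an assumption on my part, not something the patch requires:

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        mparams.use_mmap = false; // assumption: keep tensors in regular buffers before writing them back out

        // hypothetical paths, not part of the patch
        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (model == nullptr) {
            return 1;
        }

        // collects the KV metadata and tensors via llama_model_saver and writes a new GGUF file
        llama_model_save_to_file(model, "model-copy.gguf");

        llama_model_free(model);
        llama_backend_free();
        return 0;
    }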
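
Finally, the tests/test-opt.cpp edits above all follow the same ggml-opt calling-convention change: ggml_opt_forward() and ggml_opt_forward_backward() are replaced by an explicit ggml_opt_alloc() for the step followed by a single ggml_opt_eval(). A condensed sketch of the updated per-datapoint loop, assuming opt_ctx, inputs, ndata and result are set up as in helper_get_ctx_data():

    for (int idata = 0; idata < ndata; ++idata) {
        const float idataf = idata;

        // allocate the graphs for this step; pass backward = false for a forward-only evaluation
        ggml_opt_alloc(opt_ctx, /*backward =*/ true);

        // stage the inputs only after the graphs have been allocated, as in the updated tests
        ggml_backend_tensor_set(inputs, &idataf, 0, ggml_nbytes(inputs));

        // run the step and accumulate loss statistics into result
        ggml_opt_eval(opt_ctx, result);
    }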