From 3ffbbd5ce130859be91909e9b77d4c1962a6be2c Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:14:11 +0800 Subject: [PATCH 01/19] HIP: rocWMMA documentation and enabling in workflow builds (#12179) * Enable rocWMMA for Windows CI build * Enable for Ubuntu * GGML_HIP_ROCWMMA_FATTN documentation work --- .github/workflows/build.yml | 16 ++++++++++++++++ docs/build.md | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b653b1f82..7e4596ab2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -467,6 +467,7 @@ jobs: run: | cmake -B build -S . \ -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DGGML_HIP_ROCWMMA_FATTN=ON \ -DGGML_HIP=ON cmake --build build --config Release -j $(nproc) @@ -476,6 +477,7 @@ jobs: cmake -B build2 -S . \ -DCMAKE_C_COMPILER=hipcc \ -DCMAKE_CXX_COMPILER=hipcc \ + -DGGML_HIP_ROCWMMA_FATTN=ON \ -DGGML_HIP=ON cmake --build build2 --config Release -j $(nproc) @@ -1202,6 +1204,11 @@ jobs: id: checkout uses: actions/checkout@v4 + - name: Clone rocWMMA repository + id: clone_rocwmma + run: | + git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + - name: Install id: depends run: | @@ -1231,8 +1238,10 @@ jobs: cmake -G "Unix Makefiles" -B build -S . ` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_CXX_FLAGS="-Irocwmma/library/include/" ` -DCMAKE_BUILD_TYPE=Release ` -DGGML_HIP=ON ` + -DGGML_HIP_ROCWMMA_FATTN=ON ` -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} @@ -1251,6 +1260,11 @@ jobs: with: fetch-depth: 0 + - name: Clone rocWMMA repository + id: clone_rocwmma + run: | + git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: @@ -1280,8 +1294,10 @@ jobs: cmake -G "Unix Makefiles" -B build -S . 
` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_CXX_FLAGS="-Irocwmma/library/include/" ` -DCMAKE_BUILD_TYPE=Release ` -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` + -DGGML_HIP_ROCWMMA_FATTN=ON ` -DGGML_HIP=ON ` -DGGML_RPC=ON cmake --build build -j ${env:NUMBER_OF_PROCESSORS} diff --git a/docs/build.md b/docs/build.md index b3ecf043d..3d8333328 100644 --- a/docs/build.md +++ b/docs/build.md @@ -235,6 +235,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`. However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs). + To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. + + The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager. + + As an alternative, you can manually install the library by cloning it from the official [GitHub repository](https://github.com/ROCm/rocWMMA), checking out the corresponding version tag (e.g. `rocm-6.2.4`) and setting `-DCMAKE_CXX_FLAGS="-I<path/to/rocwmma>/library/include/"` in CMake. This also works under Windows despite not being officially supported by AMD. 
+ Note that if you get the following error: ``` clang: error: cannot find ROCm device library; provide its path via '--rocm-path' or '--rocm-device-lib-path', or pass '-nogpulib' to build without ROCm device library From 5220a16d18563d3ffc509002f0514415fdda4036 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 6 Mar 2025 18:45:09 +0100 Subject: [PATCH 02/19] CUDA: fix FA logic for PTX 7.0 and CC >= 7.5 (#12222) --- ggml/src/ggml-cuda/fattn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 24f973056..2e72fc8fd 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (cc == GGML_CUDA_CC_VOLTA) { + if (fp16_mma_available(cc) && !new_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } From 3d652bfddfba09022525067e672c3c145c074649 Mon Sep 17 00:00:00 2001 From: Lucas Moura Belo Date: Thu, 6 Mar 2025 16:15:13 -0300 Subject: [PATCH 03/19] readme : update bindings (#12229) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d73b0495d..e371c44ed 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp) - Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift) - Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama) +- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi) From 776f9e59cc8a85e840d1d4af8540d199c77190ac Mon Sep 17 00:00:00 2001 From: xiaofei Date: Fri, 7 Mar 2025 06:58:25 +0800 Subject: [PATCH 04/19] cmake : fix undefined reference errors 
for std::filesystem in ggml (#12092) (#12094) Signed-off-by: Ray Lee Co-authored-by: Ray Lee --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index cfd4ac54c..52817510f 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -236,7 +236,7 @@ add_library(ggml target_link_libraries(ggml PUBLIC ggml-base) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_libraries(ggml PRIVATE dl) + target_link_libraries(ggml PRIVATE dl stdc++fs) endif() function(ggml_add_backend_library backend) From d76a86d967ef491d530400b08bc8ef8a14807936 Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 6 Mar 2025 16:20:35 -0800 Subject: [PATCH 05/19] opencl: Noncontiguous `norm`, `rms_norm`, disable `fp16` for some ops (#12217) * opencl: support noncontiguous `norm` * opencl: support noncontiguous `rms_norm` * opencl: disable fp16 for `ADD`, `MUL`, `SCALE`, `RELU`, `GELU`, `SILU`, `CLAMP` --- ggml/src/ggml-opencl/ggml-opencl.cpp | 70 +++++++++++++-------- ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 26 ++++++-- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index bc2ea06b5..b85a895c4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_ADD: case GGML_OP_SCALE: case GGML_OP_MUL: - return true; + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: case GGML_OP_NORM: case GGML_OP_RMS_NORM: @@ -2573,26 
+2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const memcpy(&eps, dst->op_params, sizeof(float)); const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; - GGML_ASSERT(ggml_is_contiguous_1(src0)); + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; const int nth = MIN(64, ne00); cl_kernel kernel = backend_ctx->kernel_norm; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); - 
const int64_t nrows = ggml_nrows(src0); - - size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; #ifdef GGML_OPENCL_PROFILING @@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c memcpy(&eps, dst->op_params, sizeof(float)); const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); const int nth = MIN(64, ne00); - const int64_t nrows = ggml_nrows(src0); - - size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; cl_kernel kernel = backend_ctx->kernel_rms_norm; @@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c sizeof(local_work_size), local_work_size, sizeof(size_t), &sgs, NULL)); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 
3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); // This is local memory - the size depends on subgroup size. - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); #ifdef GGML_OPENCL_PROFILING cl_event evt; diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl index 8882a8c9c..1d43642a9 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -506,14 +506,23 @@ kernel void kernel_norm( global float * dst, ulong offsetd, int ne00, + int ne01, + int ne02, + int ne03, ulong nb01, + ulong nb02, + ulong nb03, float eps, local float * sum ) { src0 = (global void*)((global char*)src0 + offset0); dst = (global void*)((global char*)dst + offsetd); - global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01); + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); // MEAN // parallel sum @@ -533,7 +542,7 @@ kernel void kernel_norm( // recenter and VARIANCE barrier(CLK_LOCAL_MEM_FENCE); - global float * y = dst + get_group_id(0)*ne00; + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { y[i00] = x[i00] - mean; @@ -566,14 +575,23 @@ kernel void kernel_rms_norm( 
global float * dst, ulong offsetd, int ne00, + int ne01, + int ne02, + int ne03, ulong nb01, + ulong nb02, + ulong nb03, float eps, local float * sum // Note, the size depends on number of subgroups ) { src0 = (global void*)((global char*)src0 + offset0); dst = (global float*)((global char*)dst + offsetd); - global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01); + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); global float * x_scalar = (global float *) x; float4 sumf = 0; float all_sum = 0; @@ -607,7 +625,7 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00); + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); global float * y_scalar = (global float *) y; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { y[i00] = x[i00] * scale; From d6c95b0740510231b3797b80d6d3440d8fe188b6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 7 Mar 2025 06:23:16 +0100 Subject: [PATCH 06/19] metal : fix default.metallib build (#12224) This commit updates the custom command to build the default.metallib file to use the correct path to ../ggml-common.h by using the variable METALLIB_COMMON. The motivation for this change is that currently when building and specifying GGML_METAL_EMBED_LIBRARY=OFF the following error is generated: ```console [ 11%] Linking CXX shared library ../../bin/libggml.dylib [ 11%] Built target ggml make[2]: *** No rule to make target `ggml/src/ggml-metal/ggml-common.h', needed by `bin/default.metallib'. Stop. 
make[1]: *** [ggml/src/ggml-metal/CMakeFiles/ggml-metal-lib.dir/all] Error 2 ``` With the above change the build could progress but there was a follow on error about not being able to find the ggml-common.h file in ggml-metal.metal where is was included as a relative path: ```console [ 11%] Compiling Metal kernels /Users/danbev/work/llama.cpp/build/bin/ggml-metal.metal:6:10: error: '../ggml-common.h' file not found, did you mean 'ggml-common.h'? ^~~~~~~~~~~~~~~~~~ "ggml-common.h" 1 error generated. ``` Removing the relative path then allowed the build to complete successfully. --- ggml/src/ggml-metal/CMakeLists.txt | 4 ++-- ggml/src/ggml-metal/ggml-metal.metal | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 89fcde2fa..be3fb3fa9 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -27,12 +27,12 @@ configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) +set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) add_compile_definitions(GGML_METAL_EMBED_LIBRARY) - set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") @@ -93,7 +93,7 @@ else() COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal - DEPENDS ggml-metal.metal ggml-common.h + DEPENDS ggml-metal.metal ${METALLIB_COMMON} COMMENT "Compiling Metal kernels" ) diff --git a/ggml/src/ggml-metal/ggml-metal.metal 
b/ggml/src/ggml-metal/ggml-metal.metal index d092a1690..c46a13050 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3,8 +3,7 @@ #if defined(GGML_METAL_EMBED_LIBRARY) __embed_ggml-common.h__ #else -// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift -#include "../ggml-common.h" +#include "ggml-common.h" #endif #include "ggml-metal-impl.h" From f1648e91cf6c52e9593810aa70857e412d474c09 Mon Sep 17 00:00:00 2001 From: David Huang <1969802+hjc4869@users.noreply.github.com> Date: Fri, 7 Mar 2025 15:06:08 +0800 Subject: [PATCH 07/19] HIP: fix rocWMMA build flags under Windows (#12230) --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7e4596ab2..f2c81c0c2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1238,7 +1238,7 @@ jobs: cmake -G "Unix Makefiles" -B build -S . ` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-Irocwmma/library/include/" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` -DCMAKE_BUILD_TYPE=Release ` -DGGML_HIP=ON ` -DGGML_HIP_ROCWMMA_FATTN=ON ` @@ -1294,7 +1294,7 @@ jobs: cmake -G "Unix Makefiles" -B build -S . 
` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-Irocwmma/library/include/" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` -DCMAKE_BUILD_TYPE=Release ` -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` -DGGML_HIP_ROCWMMA_FATTN=ON ` From 5e2d57b2b2e43eadbe6d66ba3e873a824b95e725 Mon Sep 17 00:00:00 2001 From: BB-fat <45072480+BB-fat@users.noreply.github.com> Date: Fri, 7 Mar 2025 15:35:57 +0800 Subject: [PATCH 08/19] metal : simplify kernel arguments using a struct (#3229) (#12194) * metal : refactor im2col parameters into a struct * metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets * metal : refactor sum_rows parameters into a struct * metal : refactor soft_max parameters into a struct * metal : refactor diag_mask_inf parameters into a struct * metal : refactor ssm_conv parameters into a struct * metal : refactor ssm_scan parameters into a struct * metal : refactor get_rows parameters into a struct * metal : refactor group_norm parameters into a struct * metal : refactor conv_transpose_1d parameters into a struct * metal : refactor upscale parameters into a struct * metal : refactor pad parameters into a struct * metal : refactor pad_reflect_1d parameters into a struct * metal : refactor arange parameters into a struct * metal : refactor timestep_embedding parameters into a struct * metal : refactor argsort parameters into a struct * metal : refactor leaky_relu parameters into a struct * metal : refactor pool_2d parameters into a struct * metal : fix trailing whitespace --------- Co-authored-by: alexju --- ggml/src/ggml-metal/ggml-metal-impl.h | 235 ++++++++++ ggml/src/ggml-metal/ggml-metal.m | 466 ++++++++++--------- ggml/src/ggml-metal/ggml-metal.metal | 627 ++++++++------------------ 3 files changed, 685 insertions(+), 643 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h 
b/ggml/src/ggml-metal/ggml-metal-impl.h index e3dc25f16..a58c474eb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -285,4 +285,239 @@ typedef struct { float eps; } ggml_metal_kargs_rms_norm; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t n_groups; + float eps; +} ggml_metal_kargs_group_norm; + +typedef struct { + int32_t IC; + int32_t IL; + int32_t K; + int32_t s0; + uint64_t nb0; + uint64_t nb1; +} ggml_metal_kargs_conv_transpose_1d; + +typedef struct { + uint64_t ofs0; + uint64_t ofs1; + int32_t IW; + int32_t IH; + int32_t CHW; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int32_t d0; + int32_t d1; + int32_t N; + int32_t KH; + int32_t KW; + int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources +} ggml_metal_kargs_im2col; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne10; + int64_t ne11; + int64_t ne12; + int64_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_sum_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; +} ggml_metal_kargs_soft_max; + +typedef struct { + int64_t ne00; + int64_t ne01; + int n_past; +} ggml_metal_kargs_diag_mask_inf; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + int64_t ne11; + uint64_t nb10; + uint64_t nb11; + int64_t ne0; + int64_t ne1; + int64_t ne2; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_ssm_conv; + +typedef struct { + int64_t d_state; + int64_t d_inner; + int64_t n_seq_tokens; + 
int64_t n_seqs; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb30; + uint64_t nb31; + uint64_t nb40; + uint64_t nb41; + uint64_t nb42; + uint64_t nb50; + uint64_t nb51; + uint64_t nb52; +} ggml_metal_kargs_ssm_scan; + +typedef struct { + int64_t ne00; + uint64_t nb01; + uint64_t nb02; + int64_t ne10; + uint64_t nb10; + uint64_t nb11; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_get_rows; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float sf0; + float sf1; + float sf2; + float sf3; +} ggml_metal_kargs_upscale; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_pad; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t p0; + int32_t p1; +} ggml_metal_kargs_pad_reflect_1d; + +typedef struct { + uint64_t nb1; + int dim; + int max_period; +} ggml_metal_kargs_timestep_embedding; + +typedef struct { + float slope; +} ggml_metal_kargs_leaky_relu; + +typedef struct { + int64_t ncols; + int64_t ncols_pad; +} ggml_metal_kargs_argsort; + +typedef struct { + int64_t ne0; + float start; + float step; +} ggml_metal_kargs_arange; + +typedef struct { + int32_t k0; + int32_t k1; + int32_t s0; + int32_t s1; + int32_t p0; + int32_t p1; + int64_t 
IH; + int64_t IW; + int64_t OH; + int64_t OW; + int64_t parallel_elements; +} ggml_metal_kargs_pool_2d; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1f45ebad1..1158b285c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - // TODO: add ggml_metal_kargs struct + + ggml_metal_kargs_sum_rows args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder 
setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // TODO: add ggml_metal_kargs struct - // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - 
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; @@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_diag_mask_inf args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.n_past =*/ n_past, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; if (ne00%8 == 0) { [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; @@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_ssm_conv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 
length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_ssm_scan args = { + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb30 =*/ nb30, + /*.nb31 =*/ nb31, + /*.nb40 =*/ nb40, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb50 =*/ nb50, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder 
setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:7]; [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_get_rows args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder 
setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; @@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_group_norm args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.n_groups =*/ n_groups, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node( const int32_t CHW = IC * 
KH * KW; - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; @@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_im2col args = { + /*.ofs0 =*/ ofs0, + /*.ofs1 =*/ ofs1, + /*.IW =*/ IW, + /*.IH =*/ IH, + /*.CHW =*/ CHW, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, + /*.N =*/ N, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.KHW =*/ KH * KW, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; if (is_gt_mttpt) { - [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; - [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; - [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; - const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 
1 : 0); @@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + ggml_metal_kargs_conv_transpose_1d args = { + /*.IC =*/ IC, + /*.IL =*/ IL, + /*.K =*/ K, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder 
setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); @@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_pad args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder 
setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, ne0); @@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; + ggml_metal_kargs_pad_reflect_1d args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.p0 =*/ p0, + /*.p1 =*/ p1 + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11]; - [encoder setBytes:&nb1 
length:sizeof(nb1) atIndex:12]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14]; - [encoder setBytes:&p0 length:sizeof(p0) atIndex:15]; - [encoder setBytes:&p1 length:sizeof(p1) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, ne0); @@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_arange args = { + /*.ne0 =*/ ne0, + /*.start =*/ start, + /*.step =*/ step + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:1]; const int nth = MIN(1024, ne0); @@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_timestep_embedding args = { + /*.nb1 =*/ nb1, + /*.dim =*/ dim, + /*.max_period =*/ max_period + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int nth = MIN(1024, half); @@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_argsort args = { + /*.ncols =*/ ne00, + /*.ncols_pad =*/ ne00_padded + }; 
+ [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; @@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_leaky_relu args = { + /*.slope =*/ slope + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; const int64_t n = ggml_nelements(dst); @@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; - // TODO: add ggml_metal_kargs struct + ggml_metal_kargs_pool_2d args_pool_2d = { + /* .k0 = */ k0, + /* .k1 = */ k1, + /* .s0 = */ s0, + /* .s1 = */ s1, + /* .p0 = */ p0, + /* .p1 = */ p1, + /* .IH = */ IH, + /* .IW = */ IW, + /* .OH = */ OH, + /* .OW = */ OW, + /* .parallel_elements = */ parallel_elements + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; - [encoder setBytes:&k1 length:sizeof(int32_t) 
atIndex:3]; - [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; - [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; - [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; - [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; - [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; - [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; - [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; - [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; - [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c46a13050..ad9d42a3e 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -947,45 +947,22 @@ kernel void kernel_cos( kernel void kernel_sum_rows( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, + constant ggml_metal_kargs_sum_rows & args, uint3 tpig[[thread_position_in_grid]]) { int64_t i3 = tpig.z; int64_t i2 = tpig.y; int64_t i1 = tpig.x; - if
(i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { return; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); float row_sum = 0; - for (int64_t i0 = 0; i0 < ne00; i0++) { + for (int64_t i0 = 0; i0 < args.ne00; i0++) { row_sum += src_row[i0]; } @@ -997,36 +974,29 @@ kernel void kernel_soft_max( device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, + constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); - device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const T * pmask = src1 != src0 ? 
(device const T *) src1 + i01*ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); float slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const int64_t h = i02; - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exp); } @@ -1034,8 +1004,8 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -1059,8 +1029,8 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -1090,7 +1060,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { pdst[i00] *= inv_sum; } } @@ -1100,35 +1070,28 @@ kernel void kernel_soft_max_4( device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, + constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + const int64_t i03 = (tgpig) / (args.ne02*args.ne01); + const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; + const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); - device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + device const T * pmask = src1 != src0 ? 
(device const T *) src1 + i01*args.ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; float slope = 1.0f; - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const int64_t h = i02; - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exp); } @@ -1136,8 +1099,8 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -1162,8 +1125,8 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? 
slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -1195,7 +1158,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { pdst4[i00] *= inv_sum; } } @@ -1211,27 +1174,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, + constant ggml_metal_kargs_diag_mask_inf & args, uint3 tpig[[thread_position_in_grid]]) { const int64_t i02 = tpig[2]; const int64_t i01 = tpig[1]; const int64_t i00 = tpig[0]; - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + if (i00 > args.n_past + i01) { + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY; } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00]; } } kernel void kernel_diag_mask_inf_8( device const float4 * src0, device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, + constant ggml_metal_kargs_diag_mask_inf & args, uint3 tpig[[thread_position_in_grid]]) { const int64_t i = 2*tpig[0]; @@ -1239,42 +1198,26 @@ kernel void kernel_diag_mask_inf_8( dst[i+0] = src0[i+0]; dst[i+1] = src0[i+1]; int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01; + const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00; const int64_t i00 = i4; for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { + if (i00 + 4 + k <= args.n_past + i01) { break; } dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { + if (i00 + k > 
args.n_past + i01) { dst[i][k] = -INFINITY; } } } // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -// TODO: optimize kernel void kernel_ssm_conv_f32( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_ssm_conv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -1282,15 +1225,15 @@ kernel void kernel_ssm_conv_f32( const int64_t i2 = tgpig.y; const int64_t i3 = tgpig.z; - const int64_t nc = ne10; - //const int64_t ncs = ne00; - //const int64_t nr = ne01; - //const int64_t n_t = ne1; - //const int64_t n_s = ne2; + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; - device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); - device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); - device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); float sumf = 0.0f; @@ -1302,7 +1245,6 @@ kernel void kernel_ssm_conv_f32( } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32 -// 
TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1311,48 +1253,27 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device float * dst, - constant int64_t & d_state, - constant int64_t & d_inner, - constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb20, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb30, - constant uint64_t & nb31, - constant uint64_t & nb40, - constant uint64_t & nb41, - constant uint64_t & nb42, - constant uint64_t & nb50, - constant uint64_t & nb51, - constant uint64_t & nb52, + constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { const int64_t ir = tgpig.x; const int64_t i3 = tgpig.y; - const int64_t nc = d_state; - //const int64_t nr = d_inner; - const int64_t n_t = n_seq_tokens; - //const int64_t n_s = n_seqs; + const int64_t nc = args.d_state; + // const int64_t nr = args.d_inner; + const int64_t n_t = args.n_seq_tokens; + // const int64_t n_s = args.n_seqs; for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) 
src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); + device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); if (i2 > 0) { s0 = s; @@ -1545,22 +1466,15 @@ kernel void kernel_rms_norm( kernel void kernel_group_norm( device const float * src0, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, + constant ggml_metal_kargs_group_norm & args, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + const int64_t ne = args.ne00*args.ne01*args.ne02; + 
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups); int start = tgpig * gs; int end = start + gs; @@ -1624,7 +1538,7 @@ kernel void kernel_group_norm( } const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); + const float scale = 1.0f/sqrt(variance + args.eps); for (int j = start; j < end; j += ntg) { dst[j] *= scale; } @@ -2588,17 +2502,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_ typedef void (im2col_t)( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2608,17 +2512,7 @@ template kernel void kernel_im2col( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2639,17 +2533,17 @@ kernel void kernel_im2col( const int64_t ioh = tgpig[1]; const int64_t iow = tgpig[2]; - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw); + const int64_t offset_dst 
= (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { pdst[offset_dst] = 0.0f; } else { - const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw; + const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; pdst[offset_dst] = x[offset_src]; } } @@ -2660,20 +2554,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; typedef void (im2col_ext_t)( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - constant int32_t & N, - constant int32_t & KH, - constant int32_t & KW, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2683,53 +2564,40 @@ template kernel void kernel_im2col_ext( device const float * x, device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - constant int32_t & N, - constant int32_t & KH, - constant int32_t & KW, + constant ggml_metal_kargs_im2col & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + const int64_t KHW = (int64_t)args.KHW; - const int64_t 
d = tgpig[0] / CHW; - const int64_t chw = tgpig[0] % CHW; + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) const int64_t HW = tgpig[0] % KHW; const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= N) { + if (tpitg_0 >= args.N) { return; } - const int64_t tpitg_1 = HW / KW; - const int64_t tpitg_2 = HW % KW; + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; - const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; - const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); device T * pdst = (device T *) (dst); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { pdst[offset_dst] = 0.0f; } else { - const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; } } @@ -2740,12 +2608,7 @@ typedef void (conv_transpose_1d_t)( device const float * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); @@ -2754,29 +2617,24 @@ kernel void kernel_conv_transpose_1d( device const T * src0, device const float * src1, device char * 
dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { float v = 0.0f; - for (int64_t c = 0; c < IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; - const int32_t input_offset = c * IL; + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; + const int32_t input_offset = c * args.IL; - for (int64_t i = 0; i < IL; i++) { - if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) { - v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i]; + for (int64_t i = 0; i < args.IL; i++) { + if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { + v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; } } } - device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1); + device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1); dst_ptr[0] = v; } @@ -2786,12 +2644,7 @@ kernel void kernel_conv_transpose_1d( device const float * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); @@ -2800,38 +2653,14 @@ kernel void kernel_conv_transpose_1d( device const half * src0, device const float * src1, device char * dst, - constant int32_t & IC, - constant int32_t & IL, - constant int32_t & K, - constant int32_t & s0, - constant uint64_t & nb0, - constant uint64_t & nb1, + constant ggml_metal_kargs_conv_transpose_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 
tgpg[[threadgroups_per_grid]]); kernel void kernel_upscale_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & sf0, - constant float & sf1, - constant float & sf2, - constant float & sf3, + constant ggml_metal_kargs_upscale & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2840,15 +2669,15 @@ kernel void kernel_upscale_f32( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - const int64_t i03 = i3/sf3; - const int64_t i02 = i2/sf2; - const int64_t i01 = i1/sf1; + const int64_t i03 = i3/args.sf3; + const int64_t i02 = i2/args.sf2; + const int64_t i01 = i1/args.sf1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int64_t i00 = i0/sf0; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int64_t i00 = i0/args.sf0; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_ptr[0] = src0_ptr[0]; } @@ -2857,22 +2686,7 @@ kernel void kernel_upscale_f32( kernel void kernel_pad_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant 
uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, + constant ggml_metal_kargs_pad & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -2885,12 +2699,12 @@ kernel void kernel_pad_f32( const int64_t i02 = i2; const int64_t i01 = i1; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00) { dst_ptr[i0] = src0_ptr[i0]; } else { dst_ptr[i0] = 0.0f; @@ -2900,7 +2714,7 @@ kernel void kernel_pad_f32( return; } - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { dst_ptr[i0] = 0.0f; } } @@ -2908,21 +2722,7 @@ kernel void kernel_pad_f32( kernel void kernel_pad_reflect_1d_f32( device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant int64_t & ne0, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & p0, - constant int32_t 
& p1, + constant ggml_metal_kargs_pad_reflect_1d & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2936,17 +2736,17 @@ kernel void kernel_pad_reflect_1d_f32( const int64_t i02 = i2; const int64_t i01 = i1; - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < p0) { - dst_ptr[i0] = src0_ptr[p0 - i0]; - } else if (i0 < ne0 - p1) { - dst_ptr[i0] = src0_ptr[i0 - p0]; + if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.p0) { + dst_ptr[i0] = src0_ptr[args.p0 - i0]; + } else if (i0 < args.ne0 - args.p1) { + dst_ptr[i0] = src0_ptr[i0 - args.p0]; } else { - dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1]; + dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1]; } } } @@ -2954,44 +2754,40 @@ kernel void kernel_pad_reflect_1d_f32( kernel void kernel_arange_f32( device char * dst, - constant int64_t & ne0, - constant float & start, - constant float & step, + constant ggml_metal_kargs_arange & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { device float * dst_ptr = (device float *) dst; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = start + step * i0; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_ptr[i0] = args.start + args.step * i0; } } kernel void kernel_timestep_embedding_f32( device const char * src0, 
device char * dst, - constant uint64_t & nb1, - constant int & dim, - constant int & max_period, + constant ggml_metal_kargs_timestep_embedding & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { int i = tgpig.x; - device float * embed_data = (device float *)(dst + i*nb1); + device float * embed_data = (device float *)(dst + i*args.nb1); - int half_ = dim / 2; + int half_ = args.dim / 2; for (int j = tpitg.x; j < half_; j += ntg.x) { float timestep = ((device float *)src0)[i]; - float freq = (float)exp(-log((float)max_period) * j / half_); + float freq = (float)exp(-log((float)args.max_period) * j / half_); float arg = timestep * freq; embed_data[j ] = cos(arg); embed_data[j + half_] = sin(arg); } - if (dim % 2 != 0 && tpitg.x == 0) { - embed_data[dim] = 0.f; + if (args.dim % 2 != 0 && tpitg.x == 0) { + embed_data[args.dim] = 0.f; } } @@ -2999,8 +2795,7 @@ kernel void kernel_timestep_embedding_f32( typedef void (argsort_t)( device const float * x, device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, + constant ggml_metal_kargs_argsort & args, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -3009,8 +2804,7 @@ template kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, + constant ggml_metal_kargs_argsort & args, threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { @@ -3018,9 +2812,9 @@ kernel void kernel_argsort_f32_i32( int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols_pad) return; + if (col >= args.ncols_pad) return; - device const float * x_row = x + row * ncols; + device const float * x_row = x + row * args.ncols; threadgroup int32_t * dst_row = 
shared_values; // initialize indices @@ -3028,21 +2822,21 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols_pad; k *= 2) { + for (int k = 2; k <= args.ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + if (dst_row[col] >= args.ncols || + (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]])) ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + if (dst_row[ixj] >= args.ncols || + (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]])) ) { @@ -3055,8 +2849,8 @@ kernel void kernel_argsort_f32_i32( } // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + if (col < args.ncols) { + dst[row * args.ncols + col] = dst_row[col]; } } @@ -3066,9 +2860,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar kernel void kernel_leaky_relu_f32( device const float * src0, device float * dst, - constant float & slope, + constant ggml_metal_kargs_leaky_relu & args, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; + dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * args.slope; } // ref: https://arxiv.org/pdf/2307.08691.pdf @@ -6009,28 +5803,21 @@ kernel void kernel_get_rows_q( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; } } @@ -6039,27 +5826,20 @@ kernel void kernel_get_rows_f( device const void * src0, device const void * src1, device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const 
int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; } } @@ -6067,27 +5847,20 @@ kernel void kernel_get_rows_i32( device const void * src0, device const void * src1, device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, + constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; const int64_t i02 = i11; - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; } } 
@@ -6689,98 +6462,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel void kernel_pool_2d_max_f32( device const float * src0, device float * dst, - constant int32_t & k0, - constant int32_t & k1, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int64_t & IH, - constant int64_t & IW, - constant int64_t & OH, - constant int64_t & OW, - constant int64_t & parallel_elements, + constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= parallel_elements) { + if (gid >= args.parallel_elements) { return; } const int idx = gid; - const int I_HW = IH * IW; - const int O_HW = OH * OW; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; const int nc = idx / O_HW; - const int cur_oh = idx % O_HW / OW; - const int cur_ow = idx % O_HW % OW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; device const float * i_ptr = src0 + nc * I_HW; device float * o_ptr = dst + nc * O_HW; - const int start_h = cur_oh * s1 - p1; + const int start_h = cur_oh * args.s1 - args.p1; const int bh = MAX(0, start_h); - const int eh = MIN(IH, start_h + k1); - const int start_w = cur_ow * s0 - p0; + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; const int bw = MAX(0, start_w); - const int ew = MIN(IW, start_w + k0); + const int ew = MIN(args.IW, start_w + args.k0); float res = -INFINITY; for (int i = bh; i < eh; i += 1) { for (int j = bw; j < ew; j += 1) { - res = MAX(res, i_ptr[i * IW + j]); + res = MAX(res, i_ptr[i * args.IW + j]); } } - o_ptr[cur_oh * OW + cur_ow] = res; + o_ptr[cur_oh * args.OW + cur_ow] = res; } kernel void kernel_pool_2d_avg_f32( device const float * src0, device float * dst, - constant int32_t & k0, - constant int32_t & k1, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant 
int64_t & IH, - constant int64_t & IW, - constant int64_t & OH, - constant int64_t & OW, - constant int64_t & parallel_elements, + constant ggml_metal_kargs_pool_2d & args, uint gid[[thread_position_in_grid]]) { - if (gid >= parallel_elements) { + if (gid >= args.parallel_elements) { return; } const int idx = gid; - const int I_HW = IH * IW; - const int O_HW = OH * OW; + const int I_HW = args.IH * args.IW; + const int O_HW = args.OH * args.OW; const int nc = idx / O_HW; - const int cur_oh = idx % O_HW / OW; - const int cur_ow = idx % O_HW % OW; + const int cur_oh = idx % O_HW / args.OW; + const int cur_ow = idx % O_HW % args.OW; device const float * i_ptr = src0 + nc * I_HW; device float * o_ptr = dst + nc * O_HW; - const int start_h = cur_oh * s1 - p1; + const int start_h = cur_oh * args.s1 - args.p1; const int bh = MAX(0, start_h); - const int eh = MIN(IH, start_h + k1); - const int start_w = cur_ow * s0 - p0; + const int eh = MIN(args.IH, start_h + args.k1); + const int start_w = cur_ow * args.s0 - args.p0; const int bw = MAX(0, start_w); - const int ew = MIN(IW, start_w + k0); + const int ew = MIN(args.IW, start_w + args.k0); // const float scale = 1. / ((eh - bh) * (ew - bw)); - const float scale = 1. / (k0 * k1); + const float scale = 1. 
/ (args.k0 * args.k1); float res = 0; for (int i = bh; i < eh; i += 1) { for (int j = bw; j < ew; j += 1) { - float cur = i_ptr[i * IW + j]; + float cur = i_ptr[i * args.IW + j]; res += cur * scale; } } - o_ptr[cur_oh * OW + cur_ow] = res; + o_ptr[cur_oh * args.OW + cur_ow] = res; } From 7cf64f6beecf54c6ac71503181f154667fd4228a Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 7 Mar 2025 09:33:37 +0000 Subject: [PATCH 09/19] sync: minja - support QwQ-32B (#12235) https://github.com/google/minja/commit/8a76f7815e8a3ae00bd233c2b5a8b7d4e86564ec --- common/minja/minja.hpp | 42 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp index c58dd66e0..fa4c34d6e 100644 --- a/common/minja/minja.hpp +++ b/common/minja/minja.hpp @@ -1378,13 +1378,27 @@ struct ArgumentsExpression { } }; -static std::string strip(const std::string & s) { - auto start = s.find_first_not_of(" \t\n\r"); +static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { + auto charset = chars.empty() ? " \t\n\r" : chars; + auto start = left ? s.find_first_not_of(charset) : 0; if (start == std::string::npos) return ""; - auto end = s.find_last_not_of(" \t\n\r"); + auto end = right ? 
s.find_last_not_of(charset) : s.size() - 1; return s.substr(start, end - start + 1); } +static std::vector split(const std::string & s, const std::string & sep) { + std::vector result; + size_t start = 0; + size_t end = s.find(sep); + while (end != std::string::npos) { + result.push_back(s.substr(start, end - start)); + start = end + sep.length(); + end = s.find(sep, start); + } + result.push_back(s.substr(start)); + return result; +} + static std::string capitalize(const std::string & s) { if (s.empty()) return s; auto result = s; @@ -1467,8 +1481,26 @@ public: } else if (obj.is_string()) { auto str = obj.get(); if (method->get_name() == "strip") { - vargs.expectArgs("strip method", {0, 0}, {0, 0}); - return Value(strip(str)); + vargs.expectArgs("strip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars)); + } else if (method->get_name() == "lstrip") { + vargs.expectArgs("lstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? "" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ true, /* right= */ false)); + } else if (method->get_name() == "rstrip") { + vargs.expectArgs("rstrip method", {0, 1}, {0, 0}); + auto chars = vargs.args.empty() ? 
"" : vargs.args[0].get(); + return Value(strip(str, chars, /* left= */ false, /* right= */ true)); + } else if (method->get_name() == "split") { + vargs.expectArgs("split method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = split(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str)); From 8fad3c7a7c54a25a1ca38dfb08244df55288e675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Fri, 7 Mar 2025 11:15:33 +0100 Subject: [PATCH 10/19] server : Log original chat template parsing error (#12233) --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2a526b0e7..e1371dbf8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1900,6 +1900,7 @@ struct server_context { try { common_chat_format_example(chat_templates.get(), params.use_jinja); } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); chat_templates = common_chat_templates_init(model, "chatml"); } From ea002810a209246d034d1b6ddac387f778751588 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 7 Mar 2025 12:19:31 +0200 Subject: [PATCH 11/19] ci : fix save-load test invocations (#12245) --- ci/run.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 77c32ce00..9fc19c89d 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -352,10 +352,10 @@ function gg_run_open_llama_7b_v2 { (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" From 68d0027f3d19eb579c1863814c91e37ffa699014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Fri, 7 Mar 2025 12:54:22 +0100 Subject: [PATCH 12/19] ggml-cpu: faster AVX2 variant for IQ1_M (#12216) --- ggml/src/ggml-cpu/ggml-cpu-quants.c | 22 +++++++++++++++++----- 1 file changed, 17 
insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 2ae66591d..8c7dbd1cc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -11718,9 +11718,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const #elif defined __AVX2__ - const __m256i mask = _mm256_set1_epi16(2 * 0x7); + const __m256i mask = _mm256_set1_epi16(0x7); const __m256i mone = _mm256_set1_epi16(1); const __m256i mone8 = _mm256_set1_epi8(1); + const __m256i mtwo8 = _mm256_set1_epi8(2); + // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half. + const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0); __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); @@ -11732,6 +11735,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const uint16_t * sc = (const uint16_t *)x[i].scales; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + // Extract 3-bit scales (16 values) + __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc); + scales = _mm256_srlv_epi64(scales, scales_shift); + scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone); + + // Indices to repeat each scale 8 times. 
+ __m256i scales_idx1 = _mm256_set1_epi16(0x0100); + __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8)); __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); @@ -11777,11 +11788,12 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 2), _mm_set1_epi16(sc[ib/2] << 1)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 8), _mm_set1_epi16(sc[ib/2] >> 5)); + __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1); + __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2); + + scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8); + scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8); - scale1 = _mm256_add_epi16(_mm256_and_si256(scale1, mask), mone); - scale2 = _mm256_add_epi16(_mm256_and_si256(scale2, mask), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); const __m256i p2 = _mm256_madd_epi16(dot2, scale2); const __m256i p3 = _mm256_madd_epi16(dot3, scale1); From d6ae2fa06139e496880cbf65197c84341e9d98e7 Mon Sep 17 00:00:00 2001 From: vmobilis <75476228+vmobilis@users.noreply.github.com> Date: Fri, 7 Mar 2025 11:11:40 +0300 Subject: [PATCH 13/19] ggml : ggml_compute_forward_concat() for arbitrary tensor type (ggml/1118) * ggml_compute_forward_concat() for arbitrary tensor type * Check that tensors' type match * ggml-cpu.c: check type of source tensors * ggml-cpu.c: move tensor type check to ggml_compute_forward_concat() * ggml.c: check concatenated tensor type * Remove tensor type check from ggml_compute_forward_concat() in ggml-cpu.c ..., as it was moved to ggml.c. 
--- ggml/src/ggml-cpu/ggml-cpu.c | 143 ++++++++++++++++++++++++++++++++++- ggml/src/ggml.c | 1 + 2 files changed, 142 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c67fdd045..f2ab4c5d6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -6648,6 +6648,135 @@ static void ggml_compute_forward_repeat_back( // ggml_compute_forward_concat +static void ggml_compute_forward_concat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + const size_t len = ggml_type_size(src0->type); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const char * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03; + } else { + x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13; + } + + char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3; + + memcpy(y, x, len); + } + } + } + } +} + +static void ggml_compute_forward_concat_i8( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 
&& dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const int8_t * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_concat_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const ggml_fp16_t * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + static void ggml_compute_forward_concat_f32( const struct ggml_compute_params * params, struct 
ggml_tensor * dst) { @@ -6655,7 +6784,7 @@ static void ggml_compute_forward_concat_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -6698,6 +6827,16 @@ static void ggml_compute_forward_concat( const struct ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_I16: + { + ggml_compute_forward_concat_f16(params, dst); + } break; + case GGML_TYPE_I8: + { + ggml_compute_forward_concat_i8(params, dst); + } break; case GGML_TYPE_F32: case GGML_TYPE_I32: { @@ -6705,7 +6844,7 @@ static void ggml_compute_forward_concat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_concat_any(params, dst); } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 084240331..89409bb0e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2332,6 +2332,7 @@ struct ggml_tensor * ggml_concat( struct ggml_tensor * b, int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + GGML_ASSERT(a->type == b->type); int64_t ne[GGML_MAX_DIMS]; for (int d = 0; d < GGML_MAX_DIMS; ++d) { From 102ac1891db32c346a7b6b96145a2a23c1e4c352 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 7 Mar 2025 14:00:27 +0200 Subject: [PATCH 14/19] sync : ggml ggml-ci --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 040b53ca3..c7944d1d4 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -58ecf6b96d887e408b6869915863fa1126483d51 +c7dfe3d174f98b14801f9ed12f129179d3e7b638 From 7c7f3b7f435f41f2508e0e3010f0013cd8335156 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Fri, 7 Mar 2025 14:15:27 +0100 Subject: [PATCH 15/19] ggml : skip intermediate .air file when compiling .metallib (#12247) This 
commit updates the compilation of default.metallib to skip the intermediate .air (Apple Intermediate Representation) file. The motivation for this change is to simplify the custom command a little and avoid generating and then removing the .air file. --- ggml/src/ggml-metal/CMakeLists.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index be3fb3fa9..e22232780 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -88,9 +88,8 @@ else() add_custom_command( OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air - COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air + COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - | + xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal DEPENDS ggml-metal.metal ${METALLIB_COMMON} From 7ab364390f92b0b8d83f69821a536b424838f3f8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 7 Mar 2025 20:54:30 +0200 Subject: [PATCH 16/19] server : infill gen ends on new line (#12254) --- examples/server/server.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e1371dbf8..8386f4eeb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1312,7 +1312,7 @@ struct server_slot { return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; } - bool 
can_batch_with(server_slot & other_slot) { + bool can_batch_with(server_slot & other_slot) const { return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora); } @@ -2157,14 +2157,6 @@ struct server_context { } if (slot.has_new_line) { - // if we have already seen a new line, we stop after a certain time limit - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); - } - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent if (slot.params.n_indent > 0) { // check the current indentation @@ -2203,6 +2195,14 @@ struct server_context { // check if there is a new line in the generated text if (result.text_to_send.find('\n') != std::string::npos) { slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } } // if context shift is disabled, we stop when it reaches the context limit From 6fefc05a7a4e676780ae10b0a4d0728e5281f367 Mon Sep 17 00:00:00 2001 From: "Jason C.H" Date: Sun, 9 Mar 2025 00:02:39 +0800 Subject: [PATCH 17/19] ggml-backend : make path_str compatible with C++20 (#12269) --- AUTHORS | 1 + ggml/src/ggml-backend-reg.cpp | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/AUTHORS b/AUTHORS index 6796b2941..ddcb15638 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1045,3 +1045,4 @@ zrm 蕭澧邦 
<45505768+shou692199@users.noreply.github.com> 谢乃闻 Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> +Jason C.H diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index d0d68becd..9bedeae78 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -76,7 +76,14 @@ namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { std::string u8path; try { +#if defined(__cpp_lib_char8_t) + // C++20 and later: u8string() returns std::u8string + std::u8string u8str = path.u8string(); + u8path = std::string(reinterpret_cast<const char*>(u8str.c_str())); +#else + // C++17: u8string() returns std::string u8path = path.u8string(); +#endif } catch (...) { } return u8path; From 0fd7ca7a210bd4abc995cd728491043491dbdef7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 8 Mar 2025 18:26:00 +0200 Subject: [PATCH 18/19] authors : update (#12271) --- AUTHORS | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index ddcb15638..0af9f44ad 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,4 @@ -# date: Tue Feb 4 13:04:05 EET 2025 +# date: Sat Mar 8 18:23:52 EET 2025 # this file is auto-generated by scripts/gen-authors.sh 0cc4m @@ -8,10 +8,12 @@ 3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> 44670 <44670@users.noreply.github.com> 65a <10104049+65a@users.noreply.github.com> +708-145 <40387547+708-145@users.noreply.github.com> AN Long AT Aarni Koskela Aaron Miller +Aaron Teo <57927438+taronaeo@users.noreply.github.com> Aaryaman Vasishta Abheek Gulati Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> @@ -20,6 +22,7 @@ Adithya Balaji AdithyanI Adrian Adrian Hesketh +Adrian Kretz Adrien Gallouët Adrien Gallouët Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> @@ -28,15 +31,18 @@ AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> AidanBeltonS Aisuko Akarshan Biswas +Akarshan Biswas 
Akarshan Biswas Al Mochkin <14274697+amochkin@users.noreply.github.com> Albert Jin Alberto <57916483+albbus-stack@users.noreply.github.com> Alberto Cabrera Pérez Alberto Cabrera Pérez +Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Alex Alex Azarov Alex Azarov +Alex Brooks Alex Klinkhamer Alex Klinkhamer Alex Nguyen @@ -67,6 +73,7 @@ Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com> Andy Salerno Andy Tai Anthony Van de Gejuchte +Antoine Viallon Antonis Makropoulos Arik Poznanski Armen Kaleshian @@ -83,6 +90,7 @@ Atsushi Tatsuma Austin <77757836+teleprint-me@users.noreply.github.com> AustinMroz BADR +BB-fat <45072480+BB-fat@users.noreply.github.com> Bach Le Bailey Chittle <39804642+bachittle@users.noreply.github.com> BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> @@ -101,6 +109,7 @@ Bert Wagner Billel Mokeddem Bingan <70050083+binganao@users.noreply.github.com> Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com> +Bodhi <3882561+BodhiHu@users.noreply.github.com> Bodo Graumann Bono Lv Borislav Stanimirov @@ -128,6 +137,7 @@ CentricStorm Chad Brewbaker Changyeon Kim Chao Jiang +Charles Duffy Charles Xu <63788048+chaxu01@users.noreply.github.com> Charles Xu Chen Xi @@ -139,12 +149,14 @@ Chris Kuehl Christian Demsar Christian Demsar Christian Falch <875252+chrfalch@users.noreply.github.com> +Christian Fillion Christian Kastner Christian Kögler Christian Köhnenkamp Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> Christopher Nielsen <62156882+mascguy@users.noreply.github.com> Clark Saben <76020733+csaben@users.noreply.github.com> +Clauszy Clint Herron Conrad Kramer Corentin REGAL @@ -163,6 +175,7 @@ Daniel Hiltgen Daniel Illescas Romero Daniel Kleine <53251018+d-kleine@users.noreply.github.com> Daniele <57776841+daniandtheweb@users.noreply.github.com> +Danny Milosavljevic DannyDaemonic Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> Dave @@ -170,6 +183,7 @@ Dave Airlie 
Dave Airlie Dave Della Costa David Friehs +David Huang <1969802+hjc4869@users.noreply.github.com> David Kennedy David Pflug David Renshaw @@ -236,6 +250,7 @@ Felix Finn Voorhees Firat FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> +Florent BENOIT Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> Francisco Melo <43780565+francis2tm@users.noreply.github.com> @@ -254,6 +269,7 @@ Gary Mulder Gavin Zhao Genkagaku.GPT Georgi Gerganov +Gian-Carlo Pascutto Gilad S Gilad S. <7817232+giladgd@users.noreply.github.com> Giuseppe Scrivano @@ -267,7 +283,9 @@ Guspan Tanadi <36249910+guspan-tanadi@users.noreply.github.com> Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> Haggai Nuchi Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> +Hale Chan Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com> +Han Yin HanishKVC Haohui Mai Haoxiang Fei @@ -278,6 +296,7 @@ Haus1 Henk Poley Henri Vasserman Henrik Forstén +Henry Linjamäki Herman Semenov Hesen Peng HimariO @@ -307,6 +326,7 @@ Ivan Ivan Filipov <159561759+vanaka11@users.noreply.github.com> Ivan Komarov Ivan Stepanov +JC <43374599+MrSMlT@users.noreply.github.com> JFLFY2255 JH23X <165871467+JH23X@users.noreply.github.com> Jack Mousseau @@ -325,6 +345,7 @@ Jan Ploski Jannis Schönleber Jared Van Bortel Jared Van Bortel +Jason C.H Jason McCartney Jason Stillerman Jean-Christophe Hoelt @@ -342,6 +363,7 @@ Jiahao Li Jian Liao JidongZhang-THU <1119708529@qq.com> Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com> +Jinyang He Jiří Podivín <66251151+jpodivin@users.noreply.github.com> Jiří Sejkora Joan Fontanals @@ -379,6 +401,7 @@ Justine Tunney Juuso Alasuutari KASR Kamil Tomšík +Kante Yin Karol Kontny <82021046+kkontny@users.noreply.github.com> Karsten Weiss Karthick @@ -419,6 +442,7 @@ LoganDark Loïc Carrère LostRuins <39025047+LostRuins@users.noreply.github.com> LostRuins Concedo 
<39025047+LostRuins@users.noreply.github.com> +Lucas Moura Belo Luciano Luo Tian Lyle Dean @@ -463,6 +487,7 @@ Matthew Tejo Matvey Soloviev Max Krasnyansky Max Krasnyansky +Maxim Evtush <154841002+maximevtush@users.noreply.github.com> Maxime <672982+maximegmd@users.noreply.github.com> Maximilian Winter Meng Zhang @@ -494,6 +519,7 @@ Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> Mohammadreza Hendiani Mohammadreza Hendiani Molly Sophia +MoonRide303 <130458190+MoonRide303@users.noreply.github.com> MorganRO8 <47795945+MorganRO8@users.noreply.github.com> Murilo Santana Musab Gultekin @@ -524,6 +550,7 @@ Nikolas <127742645+nneubacher@users.noreply.github.com> Nindaleth Nuno OSecret <135510162+OLSecret@users.noreply.github.com> +Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> Oleksandr Nikitin Oleksii Maryshchenko Olivier Chafik @@ -533,6 +560,7 @@ PAB Pablo Duboue Pascal Patry Patrice Ferlet +Patrick Peng Paul Tsochantaris Pavel Zloi Pavol Rusnak @@ -549,6 +577,7 @@ Pieter Ouwerkerk Plamen Minev Prashant Vithule <119530321+Vithulep@users.noreply.github.com> Przemysław Pawełczyk +PureJourney Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> Qingyou Meng Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> @@ -564,14 +593,17 @@ Rand Xie Randall Fitzgerald Random Fly Reinforce-II +Rémy O Rémy Oudompheng Ren Xuancheng Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> Reza Kakhki +Reza Rahemtola <49811529+RezaRahemtola@users.noreply.github.com> RhinoDevel Riccardo Orlando Riceball LEE Rich Dougherty +Richard Richard Kiss Richard Roberson Rick G <26732651+TheFlipbook@users.noreply.github.com> @@ -588,6 +620,7 @@ Robert Sung-wook Shin Robey Holderith Robyn Roger Meier +Rohanjames1997 Roland <14355895+rbur0425@users.noreply.github.com> Romain Biessy Romain D <90720+Artefact2@users.noreply.github.com> @@ -610,6 +643,7 @@ Ryan Landay Ryder Wishart Ryuei Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> +SAMI 
SRHMorris <69468379+SRHMorris@users.noreply.github.com> SXX SakuraUmi @@ -634,6 +668,8 @@ Shane A Shangning Xu <32517059+xushangning@users.noreply.github.com> Shankar Shanshan Shen <467638484@qq.com> +Shelby Jenkins <47464908+ShelbyJenkins@users.noreply.github.com> +Sheldon Robinson Shijie <821898965@qq.com> Shintarou Okada Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> @@ -713,18 +749,24 @@ Victor Nogueira Victor Z. Peng Viet-Anh NGUYEN (Andrew) Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com> +Vitali Lovich +Vivian Vlad Vladimir Vladimir Malyutin +Vladimir Vuksanovic <109677816+vvuksanovic@users.noreply.github.com> Vladimir Zorin VoidIsVoid <343750470@qq.com> Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> +Wagner Bruna Wang Qin <37098874+wangqin0@users.noreply.github.com> Wang Ran (汪然) WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> Weird Constructor +Weizhao Ouyang Welby Seely Wentai Zhang +Wilken Gottwalt <12194808+wgottwalt@users.noreply.github.com> WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> William Tambellini William Tambellini @@ -816,6 +858,8 @@ chaihahaha chiranko <96988916+chiranko@users.noreply.github.com> clibdev <52199778+clibdev@users.noreply.github.com> clyang +cmdr2 +cmdr2 cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> codezjx coezbek @@ -835,6 +879,7 @@ deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> devojony <61173062+devojony@users.noreply.github.com> ditsuke divinity76 +dm4 dm4 dotpy314 <33351922+dotpy314@users.noreply.github.com> drbh @@ -849,6 +894,7 @@ fairydreaming <166155368+fairydreaming@users.noreply.github.com> fengerhu1 <2748250768@qq.com> fj-y-saito <85871716+fj-y-saito@users.noreply.github.com> fraxy-v <65565042+fraxy-v@users.noreply.github.com> +fxzjshm <11426482+fxzjshm@users.noreply.github.com> github-actions[bot] gliptic gn64 @@ -873,6 +919,7 @@ hydai iSma iacore <74560659+iacore@users.noreply.github.com> 
icppWorld <124377669+icppWorld@users.noreply.github.com> +igardev <49397134+igardev@users.noreply.github.com> igarnier intelmatt <61025942+intelmatt@users.noreply.github.com> iohub @@ -880,6 +927,7 @@ issixx <46835150+issixx@users.noreply.github.com> jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> jameswu2014 <545426914@qq.com> +jason_w jdomke <28772296+jdomke@users.noreply.github.com> jiahao su jiez <373447296@qq.com> @@ -891,6 +939,7 @@ jon-chuang <9093549+jon-chuang@users.noreply.github.com> jp-x-g jukofyork <69222624+jukofyork@users.noreply.github.com> junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> +junchao-zhao <68935141+junchao-loongson@users.noreply.github.com> jwj7140 <32943891+jwj7140@users.noreply.github.com> k.h.lai kaizau @@ -925,6 +974,7 @@ ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> luoyu-intel m3ndax maddes8cht <55592906+maddes8cht@users.noreply.github.com> +magicse mahorozte <41834471+mahorozte@users.noreply.github.com> makomk manikbhandari @@ -935,6 +985,7 @@ matt23654 matteo mdrokz mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> +midnight minarchist mj-shifu <77107165+mj-shifu@users.noreply.github.com> mmyjona @@ -958,10 +1009,12 @@ omahs <73983677+omahs@users.noreply.github.com> oobabooga <112222186+oobabooga@users.noreply.github.com> opparco ostix360 <55257054+ostix360@users.noreply.github.com> +pascal-lc <49066376+pascal-lc@users.noreply.github.com> pculliton peidaqi pengxin99 perserk +petterreinholdtsen piDack <104877312+piDack@users.noreply.github.com> pmysl postmasters @@ -983,6 +1036,7 @@ semidark serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> sharpHL <132747147+sharpHL@users.noreply.github.com> shibe2 +simon886212 <37953122+simon886212@users.noreply.github.com> singularity <12184989+singularity-s0@users.noreply.github.com> sjinzh sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> @@ -1000,10 
+1054,12 @@ tarcey tc-mb <157115220+tc-mb@users.noreply.github.com> texmex76 <40733439+texmex76@users.noreply.github.com> thement <40525767+thement@users.noreply.github.com> +theraininsky <76763719+theraininsky@users.noreply.github.com> thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> tjohnman toyer <2042519524@qq.com> tslmy +tv1wnd <55383215+tv1wnd@users.noreply.github.com> ubik2 uint256_t uint256_t @@ -1014,6 +1070,7 @@ valiray <133289098+valiray@users.noreply.github.com> vb vik viric +vmobilis <75476228+vmobilis@users.noreply.github.com> vodkaslime <646329483@qq.com> vvhg1 <94630311+vvhg1@users.noreply.github.com> vxiiduu <73044267+vxiiduu@users.noreply.github.com> @@ -1028,6 +1085,8 @@ wzy <32936898+Freed-Wu@users.noreply.github.com> xaedes xaedes xctan +xiaobing318 <71554036+xiaobing318@users.noreply.github.com> +xiaofei xloem <0xloem@gmail.com> yangli2 ymcki <84055651+ymcki@users.noreply.github.com> @@ -1045,4 +1104,3 @@ zrm 蕭澧邦 <45505768+shou692199@users.noreply.github.com> 谢乃闻 Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> -Jason C.H From 1e2f78a00450593e2dfa458796fcdd9987300dfc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Mar 2025 19:08:20 +0200 Subject: [PATCH 19/19] server : add speculative decoding presets for FIM (#12287) --- common/arg.cpp | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 3e549ede0..b96a5678f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2571,5 +2571,43 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--fim-qwen-7b-spec"}, + string_format("use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; + params.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; + 
params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + + add_opt(common_arg( + {"--fim-qwen-14b-spec"}, + string_format("use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet)"), + [](common_params & params) { + params.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; + params.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; + params.speculative.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.n_gpu_layers = 99; + params.port = 8012; + params.n_gpu_layers = 99; + params.flash_attn = true; + params.n_ubatch = 1024; + params.n_batch = 1024; + params.n_ctx = 0; + params.n_cache_reuse = 256; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + return ctx_arg; }