Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/musa.Dockerfile
#	.github/workflows/build-linux-cross.yml
#	.github/workflows/build-riscv-native.yml
#	.github/workflows/build.yml
#	.github/workflows/docker.yml
#	CODEOWNERS
#	ci/run.sh
#	ggml/CMakeLists.txt
#	ggml/src/ggml-blas/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tools/perplexity/perplexity.cpp
#	tools/server/README.md
This commit is contained in:
Concedo 2025-09-30 00:36:38 +08:00
commit 4f2b951547
48 changed files with 5353 additions and 415 deletions

52
.github/workflows/build-amd.yml vendored Normal file
View file

@ -0,0 +1,52 @@
name: CI (AMD)
on:
workflow_dispatch: # allows manual triggering
push:
branches:
- master
paths: [
'.github/workflows/build-amd.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.comp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
cancel-in-progress: true
jobs:
ggml-ci-x64-amd-vulkan:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
vulkaninfo --summary
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
ggml-ci-x64-amd-rocm:
runs-on: [self-hosted, Linux, X64, AMD]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Test
id: ggml-ci
run: |
amd-smi static
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp

View file

@ -0,0 +1,29 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)
set(CMAKE_SYSTEM_VERSION 1)
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
else()
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
if (DEFINED ENV{RISCV_ROOT_PATH})
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
else()
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
endif()
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
endif()
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")

View file

@ -1616,17 +1616,36 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
);
});
auto recipient_in_role = builder.add_rule("recipient_in_role",
"\"<|start|>assistant\"? \" to=functions.\" ( " +
string_join(tool_rules_recipient_in_role, " | ") + " )"
);
auto recipient_in_channel = builder.add_rule("recipient_in_channel",
channel + " \" to=functions.\" ( " +
string_join(tool_rules_recipient_in_channel, " | ") + " )"
);
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
if (data.grammar_lazy) {
auto recipient_in_role = builder.add_rule("recipient_in_role",
"\"<|start|>assistant\"? \" to=functions.\" ( " +
string_join(tool_rules_recipient_in_role, " | ") + " )"
);
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
} else {
auto not_end = builder.add_rule("not-end",
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
auto analysis = builder.add_rule("analysis",
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
auto commentary = builder.add_rule("commentary",
"\"<|channel|>commentary<|message|>\" ( " + not_end + " )* \"<|end|>\"");
auto recipient_in_role = builder.add_rule("recipient_in_role",
"\" to=functions.\" ( " + string_join(tool_rules_recipient_in_role, " | ") + " )"
);
builder.add_rule("root",
"( " + analysis + " \"<|start|>assistant\" )? " +
"( " + commentary + " \"<|start|>assistant\" )? " +
"( " + recipient_in_role + " | " + recipient_in_channel + " )"
);
}
// Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({

View file

@ -0,0 +1,89 @@
> [!IMPORTANT]
> This build documentation is specific only to RISC-V SpacemiT SOCs.
## Build llama.cpp locally (for riscv64)
1. Prepare Toolchain For RISCV
~~~
wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
~~~
2. Build
Below is the build script: it requires utilizing RISC-V vector instructions for acceleration. Ensure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`.
```bash
cmake -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_CPU_RISCV64_SPACEMIT=ON \
-DLLAMA_CURL=OFF \
-DGGML_RVV=ON \
-DGGML_RV_ZFH=ON \
-DGGML_RV_ZICBOP=ON \
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
-DCMAKE_INSTALL_PREFIX=build/installed
cmake --build build --parallel $(nproc) --config Release
pushd build
make install
popd
```
## Simulation
You can use QEMU to perform emulation on non-RISC-V architectures.
1. Download QEMU
~~~
wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz
~~~
2. Run Simulation
After build your llama.cpp, you can run the executable file via QEMU for simulation, for example:
~~~
export QEMU_ROOT_PATH={your QEMU file path}
export RISCV_ROOT_PATH_IME1={your RISC-V compiler path}
${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1
~~~
## Performance
#### Quantization Support For Matrix
~~~
model name : Spacemit(R) X60
isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt
mmu : sv39
uarch : spacemit,x60
mvendorid : 0x710
marchid : 0x8000000058000001
~~~
Q4_0
| Model | Size | Params | backend | threads | test | t/s |
| -----------| -------- | ------ | ------- | ------- | ---- |------|
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26|
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01|
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02|
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06|
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02|
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02|
Q4_1
| Model | Size | Params | backend | threads | test | t/s |
| -----------| -------- | ------ | ------- | ------- | ---- |------|
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12|
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01|
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25|
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15|
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16|
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128|2.25 ± 0.04|
Q4_K
| Model | Size | Params | backend | threads | test | t/s |
| -----------| -------- | ------ | ------- | ------- | ---- |------|
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05|
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04|
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10|
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08|
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04|
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00|

View file

@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
return p;
}
static const char * dl_error() {
return "";
}
#else
using dl_handle = void;
@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
@ -240,7 +249,7 @@ struct ggml_backend_registry {
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
if (!silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error());
}
return nullptr;
}
@ -531,7 +540,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
dl_handle_ptr handle { dl_load_library(entry) };
if (!handle && !silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error());
}
if (handle) {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");

View file

@ -18,6 +18,10 @@
# include "kleidiai/kleidiai.h"
#endif
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
# include "spacemit/ime.h"
#endif
#if defined(_WIN32)
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
@ -45,6 +49,12 @@ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_type
// }
// #endif
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());
}
#endif
#ifdef GGML_USE_CPU_KLEIDIAI
if (ggml_backend_cpu_kleidiai_buffer_type()) {
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,13 @@
#pragma once
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,26 @@
#pragma once
#include <cstddef>
namespace sqnbitgemm_spacemit_ime {
namespace ime1 {
size_t gemm_kernel_i8i4(size_t blk_len,
const std::byte * quant_a_ptr,
const std::byte * quant_b_data,
const float * quant_b_scale,
const std::byte * quant_b_zp,
float * c_ptr,
size_t count_m,
size_t count_n,
size_t count_k,
size_t block_count_k,
size_t ldc,
const float * bias,
const size_t scale_stride);
void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
} // namespace ime1
} // namespace sqnbitgemm_spacemit_ime

View file

@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
}

View file

@ -2032,7 +2032,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
@ -2040,7 +2040,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
@ -2120,7 +2120,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
return;
}
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
return;
}
@ -3652,9 +3652,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:

View file

@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
if (ggml_is_quantized(type)) {
return false;
@ -96,8 +96,18 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (src1_ncols > 16) {
return false;
if (mul_mat_id) {
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
return false;
}
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
return false;
}
} else {
if (src1_ncols > 16) {
return false;
}
}
switch (type) {

View file

@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@ -31,9 +31,20 @@ static __global__ void mul_mat_f(
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = has_ids ? blockIdx.y : 0;
int expert_idx = 0;
int col_base = 0;
const int channel_dst = has_ids ? 0 : blockIdx.y;
if constexpr (has_ids) {
// experts + tiles of ncols_dst are packed in the y dimension
int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
const int nchannels_x = gridDim.y / col_tiles;
const int tile_idx = blockIdx.y / nchannels_x;
expert_idx = blockIdx.y - tile_idx * nchannels_x;
col_base = tile_idx * cols_per_block;
}
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
if constexpr (has_ids) {
constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
const int64_t col_offset = col_base;
y += col_offset * stride_col_y * y_stride_scale;
dst += col_offset * stride_col_dst;
ids += col_offset * stride_row_id;
}
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
if (threadIdx.x == 0) {
slot_map[j] = -1;
}
if (col_base + j >= ncols_dst_total) {
continue;
}
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
int match = id_row[k*stride_col_id] == expert_idx;
@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
if constexpr (!has_ids) {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
} else {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
} else {
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
}
@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0) {
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nchannels_dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
if (ids) {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
template <typename T, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
@ -268,49 +298,49 @@ void mul_mat_f_cuda(
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
GGML_ASSERT(ids || ncols_dst <= 16);
switch (ncols_case) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\

View file

@ -567,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
ctx->debug_graph,
ctx->debug_fusion);
for (int idx = idx_start; idx < idx_end;) {
for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
const int res = ggml_metal_op_encode(ctx_op, idx);
if (res == 0) {
break;
}
idx += res;
idx += res - 1;
}
ggml_metal_op_free(ctx_op);

View file

@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_pipeline_set_smem(res, 8192);
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
return res;
}
@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
ggml_metal_pipeline_set_smem(res, 8192);

View file

@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);

View file

@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;
case GGML_OP_ARANGE:
return true;
case GGML_OP_FLASH_ATTN_EXT:
@ -717,8 +719,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction &&
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
return has_simdgroup_reduction;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:

View file

@ -76,6 +76,7 @@
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
// kernel argument structs
//

View file

@ -24,22 +24,88 @@ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
}
struct ggml_metal_op {
ggml_metal_op(
ggml_metal_device_t dev,
ggml_metal_cmd_buf_t cmd_buf,
ggml_cgraph * gf,
int idx_start,
int idx_end,
bool use_fusion,
bool use_concurrency,
bool use_capture,
int debug_graph,
int debug_fusion) {
this->dev = dev;
this->lib = ggml_metal_device_get_library(dev);
this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
this->mem_ranges = ggml_mem_ranges_init(debug_graph);
this->idx_start = idx_start;
this->idx_end = idx_end;
this->use_fusion = use_fusion;
this->use_concurrency = use_concurrency;
this->use_capture = use_capture;
this->debug_graph = debug_graph;
this->debug_fusion = debug_fusion;
this->gf = gf;
idxs.reserve(gf->n_nodes);
// filter empty nodes
// TODO: this can be removed when the allocator starts filtering them earlier
// https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
for (int i = idx_start; i < idx_end; i++) {
if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
idxs.push_back(i);
}
}
}
~ggml_metal_op() {
ggml_metal_encoder_end_encoding(this->enc);
ggml_metal_encoder_free(this->enc);
ggml_mem_ranges_free(this->mem_ranges);
}
int n_nodes() const {
return idxs.size();
}
ggml_tensor * node(int i) const {
assert(i >= 0 && i < (int) idxs.size());
return ggml_graph_node(gf, idxs[i]);
}
bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
assert(use_fusion);
assert(i0 >= 0 && i0 < n_nodes());
if (i0 + n_ops > n_nodes()) {
return false;
}
return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
}
ggml_metal_device_t dev;
ggml_metal_library_t lib;
ggml_metal_encoder_t enc;
ggml_mem_ranges_t mem_ranges;
ggml_cgraph * gf;
int idx_start;
int idx_end;
bool use_fusion;
bool use_concurrency;
bool use_capture;
int debug_graph;
int debug_fusion;
private:
ggml_cgraph * gf;
int idx_start;
int idx_end;
// non-empty node indices
std::vector<int> idxs;
};
ggml_metal_op_t ggml_metal_op_init(
@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init(
bool use_capture,
int debug_graph,
int debug_fusion) {
ggml_metal_op_t res = new ggml_metal_op();
*res = {
/*.dev =*/ dev,
/*.lib =*/ ggml_metal_device_get_library(dev),
/*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency),
/*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph),
/*.gf =*/ gf,
/*.idx_start =*/ idx_start,
/*.idx_end =*/ idx_end,
/*.use_fusion =*/ use_fusion,
/*.use_concurrency =*/ use_concurrency,
/*.use_capture =*/ use_capture,
/*.debug_graph =*/ debug_graph,
/*.debug_fusion =*/ debug_fusion,
};
ggml_metal_op_t res = new ggml_metal_op(
dev,
cmd_buf,
gf,
idx_start,
idx_end,
use_fusion,
use_concurrency,
use_capture,
debug_graph,
debug_fusion);
return res;
}
void ggml_metal_op_free(ggml_metal_op_t ctx) {
ggml_metal_encoder_end_encoding(ctx->enc);
ggml_metal_encoder_free(ctx->enc);
ggml_mem_ranges_free(ctx->mem_ranges);
delete ctx;
}
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
return ctx->n_nodes();
}
static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
if (!ctx->mem_ranges) {
return true;
@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor
}
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
struct ggml_cgraph * gf = ctx->gf;
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
struct ggml_tensor * node = nodes[0];
struct ggml_tensor * node = ctx->node(idx);
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
case GGML_OP_PERMUTE:
{
// noop -> next node
if (ctx->debug_graph > 0) {
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
}
} return 1;
default:
{
@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
// update the mem ranges in the encoding context
for (int i = 0; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) {
if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
}
}
@ -362,11 +423,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
if (ctx->use_capture) {
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx)));
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
}
int res = ggml_metal_op_encode_impl(ctx, idx);
if (idx + res > ctx->idx_end) {
if (idx + res > ctx->n_nodes()) {
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
"https://github.com/ggml-org/llama.cpp/pull/14849");
}
@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -876,8 +928,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1476,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm &&
op->src[1]->type == GGML_TYPE_F32 &&
ne00 % 32 == 0 && ne00 >= 64 &&
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
@ -1589,8 +1631,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
}
int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1612,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
@ -1631,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
if (props_dev->has_simdgroup_mm &&
ne00 % 32 == 0 && ne00 >= 64 &&
(ne21 >= ne21_mm_id_min)) {
GGML_ASSERT(ne00 % 4 == 0);
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}
// extra buffers for intermediate id mapping
ggml_metal_buffer_id bid_tpe = bid_dst;
@ -1687,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
ggml_metal_kargs_mul_mm_id args = {
/*.ne00 =*/ ne00,
@ -1783,8 +1818,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -1856,8 +1890,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
}
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2176,16 +2209,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
@ -2258,22 +2286,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
// across splits. idx_end indicates the last node in the current split
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
ggml_tensor * f0 = ctx->node(idx + n_fuse);
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
if (f0 != f1->src[0]) {
break;
}
// b[0] === b[1] === ...
if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) {
if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
break;
}
// only fuse ops if src1 is in the same Metal buffer
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
if (bid_fuse.metal != bid_src1.metal) {
break;
}
@ -2309,10 +2340,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
break;
@ -2344,8 +2375,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2393,8 +2423,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2445,20 +2474,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
@ -2499,38 +2523,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
fops[1] = GGML_OP_MUL;
fops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
ggml_tensor * f0 = ctx->node(idx + n_fuse);
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
if (f0 != f1->src[0]) {
break;
}
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
if (f1->src[1]->ne[0] != op->ne[0]) {
break;
}
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
if (!ggml_is_contiguous_rows(f1->src[1])) {
break;
}
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
if (f1->type != GGML_TYPE_F32) {
break;
}
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
//ctx->fuse_cnt[f1->op]++;
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
}
++n_fuse;
@ -2546,10 +2573,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
break;
@ -2585,8 +2612,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2681,8 +2707,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2752,8 +2777,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2798,8 +2822,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2852,8 +2875,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2897,8 +2919,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2944,8 +2965,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -2985,8 +3005,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -3020,8 +3039,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -3060,8 +3078,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@ -3103,8 +3120,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

View file

@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init(
void ggml_metal_op_free(ggml_metal_op_t ctx);
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
//

View file

@ -33,6 +33,7 @@ using namespace metal;
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
@ -7856,6 +7857,9 @@ kernel void kernel_set_rows_f(
}
}
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
@ -7868,7 +7872,7 @@ kernel void kernel_set_rows_f(
#define SG_MAT_ROW 8
// each block_q contains 16*nl weights
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
@ -7879,8 +7883,8 @@ kernel void kernel_mul_mm(
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
const int r0 = tgpig.y;
const int r1 = tgpig.x;
@ -7894,8 +7898,9 @@ kernel void kernel_mul_mm(
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_T8x8 ma[4];
simdgroup_float8x8 mb[2];
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
@ -7913,27 +7918,45 @@ kernel void kernel_mul_mm(
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
device const float * y = (device const float *)(src1
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ args.nb10*iy);
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
T4x4 temp_a;
dequantize_func(x, il, temp_a);
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
#pragma unroll(16)
for (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
}
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
}
} else {
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@ -7942,8 +7965,8 @@ kernel void kernel_mul_mm(
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
#pragma unroll(4)
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
@ -7971,7 +7994,8 @@ kernel void kernel_mul_mm(
}
}
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
@ -8076,7 +8100,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
@ -8090,8 +8114,8 @@ kernel void kernel_mul_mm_id(
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
const int r0 = tgpig.y;
const int r1 = tgpig.x;
@ -8114,8 +8138,9 @@ kernel void kernel_mul_mm_id(
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_T8x8 ma[4];
simdgroup_half8x8 mb[2];
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
@ -8136,27 +8161,45 @@ kernel void kernel_mul_mm_id(
device const block_q * x = (device const block_q *)(src0
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
device const float * y = (device const float *)(src1
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ args.nb10*iy);
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
T4x4 temp_a;
dequantize_func(x, il, temp_a);
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
#pragma unroll(16)
for (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}
}
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
}
} else {
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
@ -8165,8 +8208,8 @@ kernel void kernel_mul_mm_id(
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
#pragma unroll(4)
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
@ -8299,66 +8342,117 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne
// matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// matrix-vector multiplication

View file

@ -19,12 +19,14 @@
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
// to avoid conflicts with applications or other libraries who might use it.
namespace vk::detail { class DispatchLoaderDynamic; }
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
#include <vulkan/vulkan.hpp>
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
#include <algorithm>
#include <cmath>
#include <iomanip>
@ -422,6 +424,8 @@ struct vk_device_struct {
bool subgroup_ballot;
bool subgroup_clustered;
bool multi_add;
bool shader_int64;
bool buffer_device_address;
bool add_rms_fusion;
uint32_t partials_binding_alignment;
@ -669,6 +673,7 @@ struct vk_buffer_struct {
vk::MemoryPropertyFlags memory_property_flags;
void * ptr;
size_t size = 0;
vk::DeviceAddress bda_addr {};
vk_device device;
@ -1001,6 +1006,7 @@ struct vk_op_argsort_push_constants {
};
struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
uint32_t IC;
uint32_t IW; uint32_t IH;
@ -1014,6 +1020,7 @@ struct vk_op_im2col_push_constants {
};
struct vk_op_im2col_3d_push_constants {
uint64_t dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
@ -2026,10 +2033,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
return buf;
}
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
vk::MemoryAllocateFlags mem_flags {};
if (device->buffer_device_address) {
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
}
vk::BufferCreateInfo buffer_create_info{
vk::BufferCreateFlags(),
size,
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
usage_flags,
vk::SharingMode::eExclusive,
0,
nullptr,
@ -2041,6 +2055,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
@ -2052,7 +2068,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->memory_property_flags = req_flags;
try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
break;
} catch (const vk::SystemError& e) {
// loop and retry
@ -2080,6 +2096,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->device = device;
buf->size = size;
if (device->buffer_device_address) {
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
buf->bda_addr = device->device.getBufferAddress(addressInfo);
}
#ifdef GGML_VULKAN_MEMORY_DEBUG
device->memory_logger->log_allocation(buf, size);
#endif
@ -3546,14 +3567,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
}
if (device->shader_int64 && device->buffer_device_address) {
IM2COL(_bda)
} else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
IM2COL()
}
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@ -4045,6 +4072,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->vendor_id != VK_VENDOR_ID_INTEL &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress;
if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@ -4538,6 +4568,12 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
return ggml_vk_default_dispatcher_instance;
}
static void ggml_vk_instance_init() {
if (vk_instance_initialized) {
return;
@ -4545,7 +4581,7 @@ static void ggml_vk_instance_init() {
VK_LOG_DEBUG("ggml_vk_instance_init()");
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
uint32_t api_version = vk::enumerateInstanceVersion();
@ -5683,8 +5719,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
ggml_vk_queue_command_pools_cleanup(dst->device);
}
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
if (disable_split_k) {
return 1;
}
uint32_t split_k = 1;
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
@ -6009,7 +6049,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
@ -6027,8 +6067,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
const uint32_t stride_batch_d = stride_d*ne21;
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
@ -6097,7 +6138,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01;
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@ -6256,13 +6297,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}
// No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
{ d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
); // NOLINT
@ -6740,9 +6784,36 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
// Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
// where the M dimension is very large.
// Split_k doesn't work with M splitting.
const size_t nbytes = ggml_nbytes(src0);
const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
if (needs_split) {
// Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
uint32_t m_offset = 0;
while (m_offset < dst->ne[0]) {
const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
ggml_tensor dst2 = *dst;
ggml_tensor src02 = *src0;
dst2.view_src = dst->view_src ? dst->view_src : dst;
src02.view_src = src0->view_src ? src0->view_src : src0;
dst2.view_offs += m_offset * dst->nb[0];
src02.view_offs += m_offset * src0->nb[1];
dst2.ne[0] = cur_M_size;
src02.ne[1] = cur_M_size;
ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun);
m_offset += cur_M_size;
}
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
// detect 0213 permutation, and batch size of 1
src0->nb[0] <= src0->nb[2] &&
src0->nb[2] <= src0->nb[1] &&
@ -6762,7 +6833,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
} else {
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun);
}
}
@ -8622,6 +8693,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
d_sz = 1;
}
// im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
@ -9473,7 +9548,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t pelements = OW * KW * KH;
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
@ -9509,8 +9590,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
const int64_t OH = ne2;
const int64_t OW = ne1;
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
vk_op_im2col_3d_push_constants pc {};
pc.dst_addr = dst_addr;
pc.nb10 = nb10 / ggml_type_size(src1->type);
pc.nb11 = nb11 / ggml_type_size(src1->type);
pc.nb12 = nb12 / ggml_type_size(src1->type);
@ -10697,10 +10784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
ctx->semaphore_idx = 0;
const ggml_tensor * src0 = node->src[0];
const ggml_tensor * src1 = node->src[1];
const ggml_tensor * src2 = node->src[2];
const ggml_tensor * src3 = node->src[3];
ggml_tensor * src0 = node->src[0];
ggml_tensor * src1 = node->src[1];
ggml_tensor * src2 = node->src[2];
ggml_tensor * src3 = node->src[3];
switch (node->op) {
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor

View file

@ -117,6 +117,9 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@ -155,7 +158,11 @@ void main() {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
if (!KV_bounds_check || j * Bc + c < KV) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
@ -172,8 +179,11 @@ void main() {
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = Sf[r][0];
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
@ -190,6 +200,9 @@ void main() {
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
@ -203,6 +216,9 @@ void main() {
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);

View file

@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
@ -65,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif

View file

@ -152,14 +152,17 @@ void main() {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
}
@ -202,7 +205,9 @@ void main() {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
if (!KV_bounds_check || j * Bc + c < KV) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
}
barrier();
@ -210,8 +215,11 @@ void main() {
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
@ -233,6 +241,9 @@ void main() {
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);

View file

@ -5,8 +5,11 @@
#include "rte.comp"
#include "types.comp"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint batch_offset; uint offset_delta;
uint IC;
uint IW; uint IH;
@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
int d0; int d1;
} p;
#include "types.comp"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
const uint NUM_ITER = 512 / BLOCK_SIZE;
@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint gidx = gl_GlobalInvocationID.x;
@ -38,7 +43,7 @@ void main() {
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
@ -50,7 +55,7 @@ void main() {
uint current_ix = rem % p.OW;
A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
@ -66,7 +71,7 @@ void main() {
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
@ -89,7 +94,11 @@ void main() {
continue;
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
}
}

View file

@ -6,8 +6,11 @@
#include "rte.comp"
#include "types.comp"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
uint32_t misalign_offsets;
} p;
#include "types.comp"
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint32_t i = gl_GlobalInvocationID.x;
@ -100,13 +105,22 @@ void main() {
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
if (iih >= IH || iiw >= IW || iid >= ID) {
dst_addr.d = D_TYPE(0.0f);
} else {
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#else
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#endif
}
}

View file

@ -265,7 +265,6 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
#if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@ -281,6 +280,8 @@ void main() {
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)

View file

@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
return uintBitsToFloat(bits);
}
#if BDA
#extension GL_EXT_buffer_reference : enable
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable
#define BDA_STORAGE_T uint64_t
#define BDA_OFFSET_T uint64_t
#else
#define BDA_STORAGE_T uvec2
#define BDA_OFFSET_T uint
#endif
#endif // !defined(GGML_TYPES_COMP)

View file

@ -789,13 +789,15 @@ void process_shaders() {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
std::string bda_str = bda ? "_bda" : "";
std::string bda_def = bda ? "1" : "0";
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
}
}
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

View file

@ -3703,6 +3703,7 @@ struct ggml_tensor * ggml_set_rows(
result->op = GGML_OP_SET_ROWS;
result->src[0] = b;
result->src[1] = c;
result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
return result;
}

View file

@ -708,6 +708,10 @@ int main(int argc, char ** argv) {
embd.push_back(id);
if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
assistant_ss << common_token_to_piece(ctx, id, false);
}
// echo this to console
input_echo = true;
@ -825,11 +829,7 @@ int main(int argc, char ** argv) {
}
}
// if current token is not EOG, we add it to current assistant message
if (params.conversation_mode && !waiting_for_first_input) {
const auto id = common_sampler_last(smpl);
assistant_ss << common_token_to_piece(ctx, id, false);
if (!prompt.empty()) {
prompt.clear();
is_interacting = false;

Binary file not shown.

View file

@ -39,6 +39,7 @@
--sidebar-ring: oklch(0.708 0 0);
--code-background: oklch(0.225 0 0);
--code-foreground: oklch(0.875 0 0);
--layer-popover: 1000000;
}
.dark {

View file

@ -4,7 +4,6 @@
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte';
interface Props {
message: DatabaseMessage;
role: 'user' | 'assistant';
justify: 'start' | 'end';
actionsPosition: 'left' | 'right';
@ -29,7 +28,6 @@
actionsPosition,
deletionInfo,
justify,
message,
onCopy,
onEdit,
onConfirmDelete,
@ -49,26 +47,17 @@
</script>
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-{justify}">
<div
class="hidden items-center text-xs text-muted-foreground transition-opacity md:flex md:group-hover:opacity-0"
>
{new Date(message.timestamp).toLocaleTimeString(undefined, {
hour: '2-digit',
minute: '2-digit'
})}
</div>
<div
class="absolute top-0 {actionsPosition === 'left'
? 'left-0'
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity md:opacity-0 md:group-hover:opacity-100"
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity"
>
{#if siblingInfo && siblingInfo.totalSiblings > 1}
<ChatMessageBranchingControls {siblingInfo} {onNavigateToSibling} />
{/if}
<div
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-100 transition-all duration-150 md:pointer-events-none md:opacity-0 md:group-hover:pointer-events-auto md:group-hover:opacity-100"
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-100 transition-all duration-150"
>
<ActionButton icon={Copy} tooltip="Copy" onclick={onCopy} />

View file

@ -138,7 +138,6 @@
{#if message.timestamp && !isEditing}
<ChatMessageActions
{message}
role="assistant"
justify="start"
actionsPosition="left"

View file

@ -135,7 +135,6 @@
actionsPosition="right"
{deletionInfo}
justify="end"
{message}
{onConfirmDelete}
{onCopy}
{onDelete}

View file

@ -362,7 +362,8 @@
<Dialog.Root {open} onOpenChange={handleClose}>
<Dialog.Content
class="z-999999 flex h-[100vh] flex-col gap-0 rounded-none p-0 md:h-[64vh] md:rounded-lg"
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
style="max-width: 48rem;"
>
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
@ -441,7 +442,7 @@
</div>
</div>
<ScrollArea class="max-h-[calc(100vh-13.5rem)] flex-1">
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
<div class="space-y-6 p-4 md:p-6">
<div>
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
@ -456,7 +457,6 @@
{localConfig}
onConfigChange={handleConfigChange}
onThemeChange={handleThemeChange}
isMobile={false}
/>
</div>
</div>

View file

@ -13,10 +13,9 @@
localConfig: SettingsConfigType;
onConfigChange: (key: string, value: string | boolean) => void;
onThemeChange?: (theme: string) => void;
isMobile?: boolean;
}
let { fields, localConfig, onConfigChange, onThemeChange, isMobile = false }: Props = $props();
let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
</script>
{#each fields as field (field.key)}
@ -28,10 +27,10 @@
<Input
id={field.key}
value={String(localConfig[field.key] || '')}
value={String(localConfig[field.key] ?? '')}
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
class={isMobile ? 'w-full' : 'max-w-md'}
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
class="w-full md:max-w-md"
/>
{#if field.help || SETTING_CONFIG_INFO[field.key]}
<p class="mt-1 text-xs text-muted-foreground">
@ -45,10 +44,10 @@
<Textarea
id={field.key}
value={String(localConfig[field.key] || '')}
value={String(localConfig[field.key] ?? '')}
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
class={isMobile ? 'min-h-[100px] w-full' : 'min-h-[100px] max-w-2xl'}
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
class="min-h-[100px] w-full md:max-w-2xl"
/>
{#if field.help || SETTING_CONFIG_INFO[field.key]}
<p class="mt-1 text-xs text-muted-foreground">
@ -76,7 +75,7 @@
}
}}
>
<Select.Trigger class={isMobile ? 'w-full' : 'max-w-md'}>
<Select.Trigger class="w-full md:w-auto md:max-w-md">
<div class="flex items-center gap-2">
{#if selectedOption?.icon}
{@const IconComponent = selectedOption.icon}
@ -109,6 +108,7 @@
{/if}
{:else if field.type === 'checkbox'}
{@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()}
<div class="flex items-start space-x-3">
<Checkbox
id={field.key}

View file

@ -1,5 +1,6 @@
<script lang="ts">
import { Button } from '$lib/components/ui/button';
import * as AlertDialog from '$lib/components/ui/alert-dialog';
interface Props {
onReset?: () => void;
@ -8,8 +9,15 @@
let { onReset, onSave }: Props = $props();
function handleReset() {
let showResetDialog = $state(false);
function handleResetClick() {
showResetDialog = true;
}
function handleConfirmReset() {
onReset?.();
showResetDialog = false;
}
function handleSave() {
@ -18,7 +26,23 @@
</script>
<div class="flex justify-between border-t border-border/30 p-6">
<Button variant="outline" onclick={handleReset}>Reset to default</Button>
<Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
<Button onclick={handleSave}>Save settings</Button>
</div>
<AlertDialog.Root bind:open={showResetDialog}>
<AlertDialog.Content>
<AlertDialog.Header>
<AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
<AlertDialog.Description>
Are you sure you want to reset all settings to their default values? This action cannot be
undone and will permanently remove all your custom configurations.
</AlertDialog.Description>
</AlertDialog.Header>
<AlertDialog.Footer>
<AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
<AlertDialog.Action onclick={handleConfirmReset}>Reset to Default</AlertDialog.Action>
</AlertDialog.Footer>
</AlertDialog.Content>
</AlertDialog.Root>

View file

@ -87,7 +87,7 @@
<Sidebar.GroupContent>
<Sidebar.Menu>
{#each filteredConversations as conversation (conversation.id)}
<Sidebar.MenuItem class="mb-1" onclick={handleMobileSidebarItemClick}>
<Sidebar.MenuItem class="mb-1">
<ChatSidebarConversationItem
conversation={{
id: conversation.id,
@ -95,6 +95,7 @@
lastModified: conversation.lastModified,
currNode: conversation.currNode
}}
{handleMobileSidebarItemClick}
isActive={currentChatId === conversation.id}
onSelect={selectConversation}
onEdit={editConversation}

View file

@ -8,6 +8,7 @@
interface Props {
isActive?: boolean;
conversation: DatabaseConversation;
handleMobileSidebarItemClick?: () => void;
onDelete?: (id: string) => void;
onEdit?: (id: string, name: string) => void;
onSelect?: (id: string) => void;
@ -16,6 +17,7 @@
let {
conversation,
handleMobileSidebarItemClick,
onDelete,
onEdit,
onSelect,
@ -47,6 +49,7 @@
function handleConfirmEdit() {
if (!editedName.trim()) return;
showEditDialog = false;
onEdit?.(conversation.id, editedName);
}
@ -85,7 +88,12 @@
: ''}"
onclick={handleSelect}
>
<div class="text flex min-w-0 flex-1 items-center space-x-3">
<!-- svelte-ignore a11y_click_events_have_key_events -->
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="text flex min-w-0 flex-1 items-center space-x-3"
onclick={handleMobileSidebarItemClick}
>
<div class="min-w-0 flex-1">
<p class="truncate text-sm font-medium">{conversation.name}</p>
@ -178,5 +186,10 @@
&:is(:hover) :global([data-slot='dropdown-menu-trigger']) {
opacity: 1;
}
@media (max-width: 768px) {
:global([data-slot='dropdown-menu-trigger']) {
opacity: 1 !important;
}
}
}
</style>

View file

@ -37,6 +37,7 @@
<DropdownMenu.Root bind:open>
<DropdownMenu.Trigger
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
onclick={(e) => e.stopPropagation()}
>
{#if triggerTooltip}
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
@ -53,7 +54,7 @@
{/if}
</DropdownMenu.Trigger>
<DropdownMenu.Content {align} class="z-999 w-48">
<DropdownMenu.Content {align} class="z-[999999] w-48">
{#each actions as action, index (action.label)}
{#if action.separator && index > 0}
<DropdownMenu.Separator />

View file

@ -19,7 +19,15 @@
bind:ref
data-slot="alert-dialog-content"
class={cn(
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
'fixed z-[999999] grid w-full gap-4 border bg-background p-6 shadow-lg duration-200',
// Mobile: Bottom sheet behavior
'right-0 bottom-0 left-0 max-h-[100dvh] translate-x-0 translate-y-0 overflow-y-auto rounded-t-lg',
'data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:slide-out-to-bottom-full',
'data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:slide-in-from-bottom-full',
// Desktop: Centered dialog behavior
'sm:top-[50%] sm:right-auto sm:bottom-auto sm:left-[50%] sm:max-h-[100vh] sm:max-w-lg sm:translate-x-[-50%] sm:translate-y-[-50%] sm:rounded-lg',
'sm:data-[state=closed]:slide-out-to-bottom-0 sm:data-[state=closed]:zoom-out-95',
'sm:data-[state=open]:slide-in-from-bottom-0 sm:data-[state=open]:zoom-in-95',
className
)}
{...restProps}

View file

@ -13,7 +13,10 @@
<div
bind:this={ref}
data-slot="alert-dialog-footer"
class={cn('flex flex-col-reverse gap-2 sm:flex-row sm:justify-end', className)}
class={cn(
'mt-6 flex flex-row gap-2 sm:mt-0 sm:justify-end [&>*]:flex-1 sm:[&>*]:flex-none',
className
)}
{...restProps}
>
{@render children?.()}

View file

@ -25,7 +25,7 @@
bind:ref
data-slot="dialog-content"
class={cn(
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
`fixed top-[50%] left-[50%] z-50 grid max-h-[100dvh] w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg md:max-h-[100vh]`,
className
)}
{...restProps}

View file

@ -1,4 +1,5 @@
<script lang="ts">
import { onDestroy, onMount } from 'svelte';
import { Select as SelectPrimitive } from 'bits-ui';
import SelectScrollUpButton from './select-scroll-up-button.svelte';
import SelectScrollDownButton from './select-scroll-down-button.svelte';
@ -14,6 +15,76 @@
}: WithoutChild<SelectPrimitive.ContentProps> & {
portalProps?: SelectPrimitive.PortalProps;
} = $props();
let cleanupInternalListeners: (() => void) | undefined;
onMount(() => {
const listenerOptions: AddEventListenerOptions = { passive: false };
const blockOutsideWheel = (event: WheelEvent) => {
if (!ref) {
return;
}
const target = event.target as Node | null;
if (!target || !ref.contains(target)) {
event.preventDefault();
event.stopPropagation();
}
};
const blockOutsideTouchMove = (event: TouchEvent) => {
if (!ref) {
return;
}
const target = event.target as Node | null;
if (!target || !ref.contains(target)) {
event.preventDefault();
event.stopPropagation();
}
};
document.addEventListener('wheel', blockOutsideWheel, listenerOptions);
document.addEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
return () => {
document.removeEventListener('wheel', blockOutsideWheel, listenerOptions);
document.removeEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
};
});
$effect(() => {
const element = ref;
cleanupInternalListeners?.();
if (!element) {
return;
}
const stopWheelPropagation = (event: WheelEvent) => {
event.stopPropagation();
};
const stopTouchPropagation = (event: TouchEvent) => {
event.stopPropagation();
};
element.addEventListener('wheel', stopWheelPropagation);
element.addEventListener('touchmove', stopTouchPropagation);
cleanupInternalListeners = () => {
element.removeEventListener('wheel', stopWheelPropagation);
element.removeEventListener('touchmove', stopTouchPropagation);
};
});
onDestroy(() => {
cleanupInternalListeners?.();
});
</script>
<SelectPrimitive.Portal {...portalProps}>
@ -22,7 +93,7 @@
{sideOffset}
data-slot="select-content"
class={cn(
'relative z-50 max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
'relative z-[var(--layer-popover,1000000)] max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
className
)}
{...restProps}