mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-09 19:46:11 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .devops/musa.Dockerfile # .github/workflows/build-linux-cross.yml # .github/workflows/build-riscv-native.yml # .github/workflows/build.yml # .github/workflows/docker.yml # CODEOWNERS # ci/run.sh # ggml/CMakeLists.txt # ggml/src/ggml-blas/CMakeLists.txt # ggml/src/ggml-cpu/CMakeLists.txt # scripts/sync-ggml.last # tests/test-backend-ops.cpp # tools/perplexity/perplexity.cpp # tools/server/README.md
This commit is contained in:
commit
4f2b951547
48 changed files with 5353 additions and 415 deletions
52
.github/workflows/build-amd.yml
vendored
Normal file
52
.github/workflows/build-amd.yml
vendored
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
29
cmake/riscv64-spacemit-linux-gnu-gcc.cmake
Normal file
29
cmake/riscv64-spacemit-linux-gnu-gcc.cmake
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
set(CMAKE_SYSTEM_NAME Linux)
|
||||
set(CMAKE_SYSTEM_PROCESSOR riscv64)
|
||||
set(CMAKE_SYSTEM_VERSION 1)
|
||||
|
||||
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
|
||||
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
|
||||
else()
|
||||
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
|
||||
if (DEFINED ENV{RISCV_ROOT_PATH})
|
||||
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
|
||||
else()
|
||||
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
|
||||
endif()
|
||||
|
||||
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
|
||||
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
|
||||
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
|
||||
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
|
||||
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
|
||||
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
|
||||
endif()
|
||||
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
|
||||
|
|
@ -1616,17 +1616,36 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
);
|
||||
});
|
||||
|
||||
auto recipient_in_role = builder.add_rule("recipient_in_role",
|
||||
"\"<|start|>assistant\"? \" to=functions.\" ( " +
|
||||
string_join(tool_rules_recipient_in_role, " | ") + " )"
|
||||
);
|
||||
|
||||
auto recipient_in_channel = builder.add_rule("recipient_in_channel",
|
||||
channel + " \" to=functions.\" ( " +
|
||||
string_join(tool_rules_recipient_in_channel, " | ") + " )"
|
||||
);
|
||||
|
||||
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
|
||||
if (data.grammar_lazy) {
|
||||
auto recipient_in_role = builder.add_rule("recipient_in_role",
|
||||
"\"<|start|>assistant\"? \" to=functions.\" ( " +
|
||||
string_join(tool_rules_recipient_in_role, " | ") + " )"
|
||||
);
|
||||
|
||||
builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);
|
||||
} else {
|
||||
auto not_end = builder.add_rule("not-end",
|
||||
"[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]");
|
||||
auto analysis = builder.add_rule("analysis",
|
||||
"\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
auto commentary = builder.add_rule("commentary",
|
||||
"\"<|channel|>commentary<|message|>\" ( " + not_end + " )* \"<|end|>\"");
|
||||
|
||||
auto recipient_in_role = builder.add_rule("recipient_in_role",
|
||||
"\" to=functions.\" ( " + string_join(tool_rules_recipient_in_role, " | ") + " )"
|
||||
);
|
||||
|
||||
builder.add_rule("root",
|
||||
"( " + analysis + " \"<|start|>assistant\" )? " +
|
||||
"( " + commentary + " \"<|start|>assistant\" )? " +
|
||||
"( " + recipient_in_role + " | " + recipient_in_channel + " )"
|
||||
);
|
||||
}
|
||||
|
||||
// Trigger on tool calls that appear in the commentary channel
|
||||
data.grammar_triggers.push_back({
|
||||
|
|
|
|||
89
docs/build-riscv64-spacemit.md
Normal file
89
docs/build-riscv64-spacemit.md
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
> [!IMPORTANT]
|
||||
> This build documentation is specific only to RISC-V SpacemiT SOCs.
|
||||
|
||||
## Build llama.cpp locally (for riscv64)
|
||||
|
||||
1. Prepare Toolchain For RISCV
|
||||
~~~
|
||||
wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz
|
||||
~~~
|
||||
|
||||
2. Build
|
||||
Below is the build script: it requires utilizing RISC-V vector instructions for acceleration. Ensure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`.
|
||||
```bash
|
||||
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_CPU_RISCV64_SPACEMIT=ON \
|
||||
-DLLAMA_CURL=OFF \
|
||||
-DGGML_RVV=ON \
|
||||
-DGGML_RV_ZFH=ON \
|
||||
-DGGML_RV_ZICBOP=ON \
|
||||
-DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
|
||||
-DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
|
||||
-DCMAKE_INSTALL_PREFIX=build/installed
|
||||
|
||||
cmake --build build --parallel $(nproc) --config Release
|
||||
|
||||
pushd build
|
||||
make install
|
||||
popd
|
||||
```
|
||||
|
||||
## Simulation
|
||||
You can use QEMU to perform emulation on non-RISC-V architectures.
|
||||
|
||||
1. Download QEMU
|
||||
~~~
|
||||
wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz
|
||||
~~~
|
||||
|
||||
2. Run Simulation
|
||||
After build your llama.cpp, you can run the executable file via QEMU for simulation, for example:
|
||||
~~~
|
||||
export QEMU_ROOT_PATH={your QEMU file path}
|
||||
export RISCV_ROOT_PATH_IME1={your RISC-V compiler path}
|
||||
|
||||
${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1
|
||||
~~~
|
||||
## Performance
|
||||
#### Quantization Support For Matrix
|
||||
~~~
|
||||
model name : Spacemit(R) X60
|
||||
isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt
|
||||
mmu : sv39
|
||||
uarch : spacemit,x60
|
||||
mvendorid : 0x710
|
||||
marchid : 0x8000000058000001
|
||||
~~~
|
||||
|
||||
Q4_0
|
||||
| Model | Size | Params | backend | threads | test | t/s |
|
||||
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26|
|
||||
Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01|
|
||||
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02|
|
||||
Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06|
|
||||
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02|
|
||||
Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02|
|
||||
|
||||
Q4_1
|
||||
| Model | Size | Params | backend | threads | test | t/s |
|
||||
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12|
|
||||
Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01|
|
||||
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25|
|
||||
Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15|
|
||||
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16|
|
||||
Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128|2.25 ± 0.04|
|
||||
|
||||
|
||||
Q4_K
|
||||
| Model | Size | Params | backend | threads | test | t/s |
|
||||
| -----------| -------- | ------ | ------- | ------- | ---- |------|
|
||||
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05|
|
||||
Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04|
|
||||
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10|
|
||||
Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08|
|
||||
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04|
|
||||
Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00|
|
||||
|
|
@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
|||
return p;
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
|
@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
|
|||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
|
@ -240,7 +249,7 @@ struct ggml_backend_registry {
|
|||
dl_handle_ptr handle { dl_load_library(path) };
|
||||
if (!handle) {
|
||||
if (!silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -531,7 +540,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
|||
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
||||
dl_handle_ptr handle { dl_load_library(entry) };
|
||||
if (!handle && !silent) {
|
||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
|
||||
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error());
|
||||
}
|
||||
if (handle) {
|
||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@
|
|||
# include "kleidiai/kleidiai.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
|
||||
# include "spacemit/ime.h"
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
|
|
@ -45,6 +49,12 @@ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_type
|
|||
// }
|
||||
// #endif
|
||||
|
||||
#ifdef GGML_USE_CPU_RISCV64_SPACEMIT
|
||||
if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type());
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CPU_KLEIDIAI
|
||||
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
||||
|
|
|
|||
1024
ggml/src/ggml-cpu/spacemit/ime.cpp
Normal file
1024
ggml/src/ggml-cpu/spacemit/ime.cpp
Normal file
File diff suppressed because it is too large
Load diff
13
ggml/src/ggml-cpu/spacemit/ime.h
Normal file
13
ggml/src/ggml-cpu/spacemit/ime.h
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml-alloc.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
3196
ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
Normal file
3196
ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp
Normal file
File diff suppressed because it is too large
Load diff
26
ggml/src/ggml-cpu/spacemit/ime_kernels.h
Normal file
26
ggml/src/ggml-cpu/spacemit/ime_kernels.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace sqnbitgemm_spacemit_ime {
|
||||
namespace ime1 {
|
||||
size_t gemm_kernel_i8i4(size_t blk_len,
|
||||
const std::byte * quant_a_ptr,
|
||||
const std::byte * quant_b_data,
|
||||
const float * quant_b_scale,
|
||||
const std::byte * quant_b_zp,
|
||||
float * c_ptr,
|
||||
size_t count_m,
|
||||
size_t count_n,
|
||||
size_t count_k,
|
||||
size_t block_count_k,
|
||||
size_t ldc,
|
||||
const float * bias,
|
||||
const size_t scale_stride);
|
||||
|
||||
void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
|
||||
|
||||
void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr);
|
||||
|
||||
} // namespace ime1
|
||||
} // namespace sqnbitgemm_spacemit_ime
|
||||
|
|
@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co
|
|||
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
||||
for (int j = 0; j < GGML_F32_ARR; j++) {
|
||||
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
||||
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
|
||||
ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs);
|
||||
|
||||
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2032,7 +2032,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
|
|
@ -2040,7 +2040,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
|
||||
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
|
|
@ -2120,7 +2120,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
|
||||
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
|
||||
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
|
|
@ -3652,9 +3652,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
|||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {
|
||||
|
||||
if (ggml_is_quantized(type)) {
|
||||
return false;
|
||||
|
|
@ -96,8 +96,18 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
||||
return false;
|
||||
}
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
|
||||
if (mul_mat_id) {
|
||||
if (type == GGML_TYPE_F32 && src1_ncols > 32) {
|
||||
return false;
|
||||
}
|
||||
if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (src1_ncols > 16) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;
|
|||
|
||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
|
||||
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
|
||||
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int stride_col_id, const int stride_row_id,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
|
|
@ -31,9 +31,20 @@ static __global__ void mul_mat_f(
|
|||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = has_ids ? blockIdx.y : 0;
|
||||
int expert_idx = 0;
|
||||
int col_base = 0;
|
||||
|
||||
const int channel_dst = has_ids ? 0 : blockIdx.y;
|
||||
|
||||
if constexpr (has_ids) {
|
||||
// experts + tiles of ncols_dst are packed in the y dimension
|
||||
int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
|
||||
const int nchannels_x = gridDim.y / col_tiles;
|
||||
const int tile_idx = blockIdx.y / nchannels_x;
|
||||
expert_idx = blockIdx.y - tile_idx * nchannels_x;
|
||||
col_base = tile_idx * cols_per_block;
|
||||
}
|
||||
|
||||
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
|
||||
const int channel_y = channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
|
|
@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
|
|||
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
|
||||
|
||||
if constexpr (has_ids) {
|
||||
constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
|
||||
const int64_t col_offset = col_base;
|
||||
y += col_offset * stride_col_y * y_stride_scale;
|
||||
dst += col_offset * stride_col_dst;
|
||||
ids += col_offset * stride_row_id;
|
||||
}
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
|
|
@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
|
|||
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
slot_map[j] = -1;
|
||||
}
|
||||
|
||||
if (col_base + j >= ncols_dst_total) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
||||
int match = id_row[k*stride_col_id] == expert_idx;
|
||||
|
||||
|
|
@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
|
|||
if constexpr (!has_ids) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||
} else {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
|
|
@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
|
|||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
} else {
|
||||
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
}
|
||||
|
|
@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
|
|||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
} else {
|
||||
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||
if (slot >= 0) {
|
||||
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids, dst,
|
||||
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
|
|
@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
|
|||
template<typename T, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nchannels_dst,
|
||||
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
|
|
@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
|
|||
template <typename T, int cols_per_block>
|
||||
void mul_mat_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
|
|
@ -268,49 +298,49 @@ void mul_mat_f_cuda(
|
|||
switch (nwarps_best) {
|
||||
case 1: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
} break;
|
||||
|
|
@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
switch (ncols_dst) {
|
||||
|
||||
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
||||
|
||||
GGML_ASSERT(ids || ncols_dst <= 16);
|
||||
|
||||
switch (ncols_case) {
|
||||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
} break;
|
||||
|
|
@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
|
||||
template void mul_mat_f_cuda<T, ncols_dst>( \
|
||||
const T * x, const float * y, const int32_t * ids, float * dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
|
||||
const int64_t stride_col_id, const int64_t stride_row_id, \
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||
|
|
|
|||
|
|
@ -567,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
|||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
for (int idx = idx_start; idx < idx_end;) {
|
||||
for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
|
||||
const int res = ggml_metal_op_encode(ctx_op, idx);
|
||||
if (res == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
idx += res;
|
||||
idx += res - 1;
|
||||
}
|
||||
|
||||
ggml_metal_op_free(ctx_op);
|
||||
|
|
|
|||
|
|
@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
||||
ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
|
||||
|
|
|
|||
|
|
@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me
|
|||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
|
|
|||
|
|
@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
case GGML_OP_ARANGE:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
|
|
@ -717,8 +719,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction &&
|
||||
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@
|
|||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
#define FC_MUL_MM 500
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
|
|
|
|||
|
|
@ -24,22 +24,88 @@ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
|||
}
|
||||
|
||||
struct ggml_metal_op {
|
||||
ggml_metal_op(
|
||||
ggml_metal_device_t dev,
|
||||
ggml_metal_cmd_buf_t cmd_buf,
|
||||
ggml_cgraph * gf,
|
||||
int idx_start,
|
||||
int idx_end,
|
||||
bool use_fusion,
|
||||
bool use_concurrency,
|
||||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion) {
|
||||
this->dev = dev;
|
||||
this->lib = ggml_metal_device_get_library(dev);
|
||||
this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
|
||||
this->mem_ranges = ggml_mem_ranges_init(debug_graph);
|
||||
this->idx_start = idx_start;
|
||||
this->idx_end = idx_end;
|
||||
this->use_fusion = use_fusion;
|
||||
this->use_concurrency = use_concurrency;
|
||||
this->use_capture = use_capture;
|
||||
this->debug_graph = debug_graph;
|
||||
this->debug_fusion = debug_fusion;
|
||||
this->gf = gf;
|
||||
|
||||
idxs.reserve(gf->n_nodes);
|
||||
|
||||
// filter empty nodes
|
||||
// TODO: this can be removed when the allocator starts filtering them earlier
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
|
||||
for (int i = idx_start; i < idx_end; i++) {
|
||||
if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
|
||||
idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~ggml_metal_op() {
|
||||
ggml_metal_encoder_end_encoding(this->enc);
|
||||
ggml_metal_encoder_free(this->enc);
|
||||
ggml_mem_ranges_free(this->mem_ranges);
|
||||
}
|
||||
|
||||
int n_nodes() const {
|
||||
return idxs.size();
|
||||
}
|
||||
|
||||
ggml_tensor * node(int i) const {
|
||||
assert(i >= 0 && i < (int) idxs.size());
|
||||
return ggml_graph_node(gf, idxs[i]);
|
||||
}
|
||||
|
||||
bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
|
||||
assert(use_fusion);
|
||||
assert(i0 >= 0 && i0 < n_nodes());
|
||||
|
||||
if (i0 + n_ops > n_nodes()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
|
||||
}
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_library_t lib;
|
||||
ggml_metal_encoder_t enc;
|
||||
ggml_mem_ranges_t mem_ranges;
|
||||
|
||||
ggml_cgraph * gf;
|
||||
|
||||
int idx_start;
|
||||
int idx_end;
|
||||
|
||||
bool use_fusion;
|
||||
bool use_concurrency;
|
||||
bool use_capture;
|
||||
|
||||
int debug_graph;
|
||||
int debug_fusion;
|
||||
|
||||
private:
|
||||
ggml_cgraph * gf;
|
||||
|
||||
int idx_start;
|
||||
int idx_end;
|
||||
|
||||
// non-empty node indices
|
||||
std::vector<int> idxs;
|
||||
};
|
||||
|
||||
ggml_metal_op_t ggml_metal_op_init(
|
||||
|
|
@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init(
|
|||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion) {
|
||||
ggml_metal_op_t res = new ggml_metal_op();
|
||||
|
||||
*res = {
|
||||
/*.dev =*/ dev,
|
||||
/*.lib =*/ ggml_metal_device_get_library(dev),
|
||||
/*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency),
|
||||
/*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph),
|
||||
/*.gf =*/ gf,
|
||||
/*.idx_start =*/ idx_start,
|
||||
/*.idx_end =*/ idx_end,
|
||||
/*.use_fusion =*/ use_fusion,
|
||||
/*.use_concurrency =*/ use_concurrency,
|
||||
/*.use_capture =*/ use_capture,
|
||||
/*.debug_graph =*/ debug_graph,
|
||||
/*.debug_fusion =*/ debug_fusion,
|
||||
};
|
||||
ggml_metal_op_t res = new ggml_metal_op(
|
||||
dev,
|
||||
cmd_buf,
|
||||
gf,
|
||||
idx_start,
|
||||
idx_end,
|
||||
use_fusion,
|
||||
use_concurrency,
|
||||
use_capture,
|
||||
debug_graph,
|
||||
debug_fusion);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx) {
|
||||
ggml_metal_encoder_end_encoding(ctx->enc);
|
||||
ggml_metal_encoder_free(ctx->enc);
|
||||
ggml_mem_ranges_free(ctx->mem_ranges);
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
|
||||
return ctx->n_nodes();
|
||||
}
|
||||
|
||||
static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
|
||||
if (!ctx->mem_ranges) {
|
||||
return true;
|
||||
|
|
@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor
|
|||
}
|
||||
|
||||
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
struct ggml_cgraph * gf = ctx->gf;
|
||||
|
||||
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
|
||||
struct ggml_tensor * node = nodes[0];
|
||||
struct ggml_tensor * node = ctx->node(idx);
|
||||
|
||||
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
||||
|
||||
|
|
@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
// noop -> next node
|
||||
if (ctx->debug_graph > 0) {
|
||||
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
|
||||
}
|
||||
} return 1;
|
||||
default:
|
||||
{
|
||||
|
|
@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
// update the mem ranges in the encoding context
|
||||
for (int i = 0; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) {
|
||||
if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
}
|
||||
|
|
@ -362,11 +423,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
|
||||
if (ctx->use_capture) {
|
||||
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx)));
|
||||
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
|
||||
}
|
||||
|
||||
int res = ggml_metal_op_encode_impl(ctx, idx);
|
||||
if (idx + res > ctx->idx_end) {
|
||||
if (idx + res > ctx->n_nodes()) {
|
||||
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
||||
}
|
||||
|
|
@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -876,8 +928,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1476,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
|
|
@ -1589,8 +1631,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1612,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
|
||||
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
|
||||
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
|
|
@ -1631,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
if (props_dev->has_simdgroup_mm &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
(ne21 >= ne21_mm_id_min)) {
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
// extra buffers for intermediate id mapping
|
||||
ggml_metal_buffer_id bid_tpe = bid_dst;
|
||||
|
|
@ -1687,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
|
|
@ -1783,8 +1818,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -1856,8 +1890,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2176,16 +2209,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
|
||||
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int idx_end = ctx->idx_end;
|
||||
|
||||
const bool use_fusion = ctx->use_fusion;
|
||||
|
||||
const int debug_fusion = ctx->debug_fusion;
|
||||
|
|
@ -2258,22 +2286,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
||||
// across splits. idx_end indicates the last node in the current split
|
||||
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
||||
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
||||
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
|
||||
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
||||
ggml_tensor * f0 = ctx->node(idx + n_fuse);
|
||||
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
|
||||
|
||||
if (f0 != f1->src[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
// b[0] === b[1] === ...
|
||||
if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) {
|
||||
if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only fuse ops if src1 is in the same Metal buffer
|
||||
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
||||
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
|
||||
if (bid_fuse.metal != bid_src1.metal) {
|
||||
break;
|
||||
}
|
||||
|
|
@ -2309,10 +2340,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
|
||||
for (int i = 1; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
break;
|
||||
|
|
@ -2344,8 +2375,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2393,8 +2423,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2445,20 +2474,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int idx_end = ctx->idx_end;
|
||||
|
||||
const bool use_fusion = ctx->use_fusion;
|
||||
|
||||
const int debug_fusion = ctx->debug_fusion;
|
||||
|
||||
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
|
|
@ -2499,38 +2523,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|||
fops[1] = GGML_OP_MUL;
|
||||
fops[2] = GGML_OP_ADD;
|
||||
|
||||
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
||||
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
||||
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
|
||||
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
||||
ggml_tensor * f0 = ctx->node(idx + n_fuse);
|
||||
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
|
||||
|
||||
if (f0 != f1->src[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
||||
if (f1->src[1]->ne[0] != op->ne[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
||||
if (!ggml_is_contiguous_rows(f1->src[1])) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
||||
if (f1->type != GGML_TYPE_F32) {
|
||||
break;
|
||||
}
|
||||
|
||||
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
||||
//ctx->fuse_cnt[f1->op]++;
|
||||
|
||||
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
||||
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
|
||||
|
||||
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
||||
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
||||
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
||||
args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
|
||||
args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
|
||||
args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
|
||||
|
||||
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
||||
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
||||
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
||||
args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
|
||||
args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
|
||||
args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
|
||||
}
|
||||
|
||||
++n_fuse;
|
||||
|
|
@ -2546,10 +2573,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
|
||||
for (int i = 1; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
break;
|
||||
|
|
@ -2585,8 +2612,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2681,8 +2707,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2752,8 +2777,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2798,8 +2822,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2852,8 +2875,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2897,8 +2919,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2944,8 +2965,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -2985,8 +3005,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -3020,8 +3039,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -3060,8 +3078,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
@ -3103,8 +3120,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
|
||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init(
|
|||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ using namespace metal;
|
|||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||
typedef matrix<bfloat, 2, 4> bfloat2x4;
|
||||
#endif
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
|
|
@ -7856,6 +7857,9 @@ kernel void kernel_set_rows_f(
|
|||
}
|
||||
}
|
||||
|
||||
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32
|
||||
|
|
@ -7868,7 +7872,7 @@ kernel void kernel_set_rows_f(
|
|||
#define SG_MAT_ROW 8
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * src0,
|
||||
|
|
@ -7879,8 +7883,8 @@ kernel void kernel_mul_mm(
|
|||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
|
|
@ -7894,8 +7898,9 @@ kernel void kernel_mul_mm(
|
|||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
|
|
@ -7913,27 +7918,45 @@ kernel void kernel_mul_mm(
|
|||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const float * y = (device const float *)(src1
|
||||
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
+ args.nb10*iy);
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
|
@ -7942,8 +7965,8 @@ kernel void kernel_mul_mm(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
|
|
@ -7971,7 +7994,8 @@ kernel void kernel_mul_mm(
|
|||
}
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
|
||||
if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) {
|
||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
|
|
@ -8076,7 +8100,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_
|
|||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
|
|
@ -8090,8 +8114,8 @@ kernel void kernel_mul_mm_id(
|
|||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
|
|
@ -8114,8 +8138,9 @@ kernel void kernel_mul_mm_id(
|
|||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_half8x8 mb[2];
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
|
|
@ -8136,27 +8161,45 @@ kernel void kernel_mul_mm_id(
|
|||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
device const float * y = (device const float *)(src1
|
||||
const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
+ args.nb10*iy);
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
|
||||
}
|
||||
} else {
|
||||
*(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
|
@ -8165,8 +8208,8 @@ kernel void kernel_mul_mm_id(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
|
|
@ -8299,66 +8342,117 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne
|
|||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
|
|
|||
|
|
@ -19,12 +19,14 @@
|
|||
|
||||
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
|
||||
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
|
||||
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
|
||||
// to avoid conflicts with applications or other libraries who might use it.
|
||||
namespace vk::detail { class DispatchLoaderDynamic; }
|
||||
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
|
||||
VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
|
|
@ -422,6 +424,8 @@ struct vk_device_struct {
|
|||
bool subgroup_ballot;
|
||||
bool subgroup_clustered;
|
||||
bool multi_add;
|
||||
bool shader_int64;
|
||||
bool buffer_device_address;
|
||||
|
||||
bool add_rms_fusion;
|
||||
uint32_t partials_binding_alignment;
|
||||
|
|
@ -669,6 +673,7 @@ struct vk_buffer_struct {
|
|||
vk::MemoryPropertyFlags memory_property_flags;
|
||||
void * ptr;
|
||||
size_t size = 0;
|
||||
vk::DeviceAddress bda_addr {};
|
||||
|
||||
vk_device device;
|
||||
|
||||
|
|
@ -1001,6 +1006,7 @@ struct vk_op_argsort_push_constants {
|
|||
};
|
||||
|
||||
struct vk_op_im2col_push_constants {
|
||||
uint64_t dst_addr;
|
||||
uint32_t batch_offset; uint32_t offset_delta;
|
||||
uint32_t IC;
|
||||
uint32_t IW; uint32_t IH;
|
||||
|
|
@ -1014,6 +1020,7 @@ struct vk_op_im2col_push_constants {
|
|||
};
|
||||
|
||||
struct vk_op_im2col_3d_push_constants {
|
||||
uint64_t dst_addr;
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
|
|
@ -2026,10 +2033,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
return buf;
|
||||
}
|
||||
|
||||
vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
|
||||
vk::MemoryAllocateFlags mem_flags {};
|
||||
if (device->buffer_device_address) {
|
||||
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
|
||||
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
|
||||
}
|
||||
|
||||
vk::BufferCreateInfo buffer_create_info{
|
||||
vk::BufferCreateFlags(),
|
||||
size,
|
||||
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
|
||||
usage_flags,
|
||||
vk::SharingMode::eExclusive,
|
||||
0,
|
||||
nullptr,
|
||||
|
|
@ -2041,6 +2055,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
|
||||
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
|
||||
|
||||
const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
|
||||
|
||||
for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
|
||||
const auto & req_flags = *it;
|
||||
|
||||
|
|
@ -2052,7 +2068,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
buf->memory_property_flags = req_flags;
|
||||
|
||||
try {
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
|
||||
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
|
||||
break;
|
||||
} catch (const vk::SystemError& e) {
|
||||
// loop and retry
|
||||
|
|
@ -2080,6 +2096,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
|
|||
buf->device = device;
|
||||
buf->size = size;
|
||||
|
||||
if (device->buffer_device_address) {
|
||||
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
|
||||
buf->bda_addr = device->device.getBufferAddress(addressInfo);
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
device->memory_logger->log_allocation(buf, size);
|
||||
#endif
|
||||
|
|
@ -3546,14 +3567,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
if (device->float_controls_rte_fp16) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
#define IM2COL(bda) \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
if (device->float_controls_rte_fp16) { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
|
||||
}
|
||||
if (device->shader_int64 && device->buffer_device_address) {
|
||||
IM2COL(_bda)
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
|
||||
IM2COL()
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||
|
|
@ -4045,6 +4072,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device->vendor_id != VK_VENDOR_ID_INTEL &&
|
||||
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
|
||||
|
||||
device->shader_int64 = device_features2.features.shaderInt64;
|
||||
device->buffer_device_address = vk12_features.bufferDeviceAddress;
|
||||
|
||||
if (device->subgroup_size_control) {
|
||||
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
|
||||
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
|
||||
|
|
@ -4538,6 +4568,12 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
|||
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
||||
|
||||
static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
|
||||
|
||||
vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
|
||||
return ggml_vk_default_dispatcher_instance;
|
||||
}
|
||||
|
||||
static void ggml_vk_instance_init() {
|
||||
if (vk_instance_initialized) {
|
||||
return;
|
||||
|
|
@ -4545,7 +4581,7 @@ static void ggml_vk_instance_init() {
|
|||
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
||||
|
||||
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
|
||||
VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
|
||||
ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
|
||||
|
||||
uint32_t api_version = vk::enumerateInstanceVersion();
|
||||
|
||||
|
|
@ -5683,8 +5719,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
|
|||
ggml_vk_queue_command_pools_cleanup(dst->device);
|
||||
}
|
||||
|
||||
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
||||
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
|
||||
|
||||
if (disable_split_k) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
uint32_t split_k = 1;
|
||||
if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
|
||||
|
|
@ -6009,7 +6049,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
|
||||
|
|
@ -6027,8 +6067,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
const uint64_t ne12 = src1->ne[2];
|
||||
const uint64_t ne13 = src1->ne[3];
|
||||
|
||||
const uint64_t ne20 = dst->ne[0];
|
||||
const uint64_t ne21 = dst->ne[1];
|
||||
const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
|
||||
const uint32_t stride_batch_d = stride_d*ne21;
|
||||
|
||||
const uint64_t r2 = ne12 / ne02;
|
||||
const uint64_t r3 = ne13 / ne03;
|
||||
|
|
@ -6097,7 +6138,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
const int y_ne = padded_n * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
||||
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
|
||||
|
||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||
|
|
@ -6256,13 +6297,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
|
||||
}
|
||||
|
||||
// No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
|
||||
VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
|
||||
|
||||
// compute
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, pipeline,
|
||||
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
|
||||
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
|
||||
{ d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
|
||||
ne01, ne11, ne10,
|
||||
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
|
||||
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
|
||||
); // NOLINT
|
||||
|
||||
|
|
@ -6740,9 +6784,36 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
|
||||
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
||||
|
||||
// Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
|
||||
// where the M dimension is very large.
|
||||
// Split_k doesn't work with M splitting.
|
||||
const size_t nbytes = ggml_nbytes(src0);
|
||||
const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
|
||||
if (needs_split) {
|
||||
// Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
|
||||
const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
|
||||
uint32_t m_offset = 0;
|
||||
while (m_offset < dst->ne[0]) {
|
||||
const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
|
||||
ggml_tensor dst2 = *dst;
|
||||
ggml_tensor src02 = *src0;
|
||||
|
||||
dst2.view_src = dst->view_src ? dst->view_src : dst;
|
||||
src02.view_src = src0->view_src ? src0->view_src : src0;
|
||||
|
||||
dst2.view_offs += m_offset * dst->nb[0];
|
||||
src02.view_offs += m_offset * src0->nb[1];
|
||||
dst2.ne[0] = cur_M_size;
|
||||
src02.ne[1] = cur_M_size;
|
||||
|
||||
ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun);
|
||||
|
||||
m_offset += cur_M_size;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
||||
// detect 0213 permutation, and batch size of 1
|
||||
src0->nb[0] <= src0->nb[2] &&
|
||||
src0->nb[2] <= src0->nb[1] &&
|
||||
|
|
@ -6762,7 +6833,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
|
||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||
} else {
|
||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -8622,6 +8693,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
|
||||
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
|
||||
// buffer device address path doesn't use dst buffer
|
||||
d_sz = 1;
|
||||
}
|
||||
// im2col uses only src1 and dst buffers
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_COUNT_EQUAL) {
|
||||
|
|
@ -9473,7 +9548,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
|
||||
const uint32_t pelements = OW * KW * KH;
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
||||
|
||||
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
||||
|
||||
ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
|
||||
dst_addr,
|
||||
batch_offset, offset_delta,
|
||||
IC, IW, IH, OW, OH, KW, KH,
|
||||
pelements,
|
||||
|
|
@ -9509,8 +9590,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
||||
|
||||
const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
|
||||
|
||||
vk_op_im2col_3d_push_constants pc {};
|
||||
|
||||
pc.dst_addr = dst_addr;
|
||||
pc.nb10 = nb10 / ggml_type_size(src1->type);
|
||||
pc.nb11 = nb11 / ggml_type_size(src1->type);
|
||||
pc.nb12 = nb12 / ggml_type_size(src1->type);
|
||||
|
|
@ -10697,10 +10784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
|
||||
ctx->semaphore_idx = 0;
|
||||
|
||||
const ggml_tensor * src0 = node->src[0];
|
||||
const ggml_tensor * src1 = node->src[1];
|
||||
const ggml_tensor * src2 = node->src[2];
|
||||
const ggml_tensor * src3 = node->src[3];
|
||||
ggml_tensor * src0 = node->src[0];
|
||||
ggml_tensor * src1 = node->src[1];
|
||||
ggml_tensor * src2 = node->src[2];
|
||||
ggml_tensor * src3 = node->src[3];
|
||||
|
||||
switch (node->op) {
|
||||
// Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
|
||||
|
|
|
|||
|
|
@ -117,6 +117,9 @@ void main() {
|
|||
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
|
|
@ -155,7 +158,11 @@ void main() {
|
|||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
|
@ -172,8 +179,11 @@ void main() {
|
|||
|
||||
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
rowmaxf[r] = Sf[r][0];
|
||||
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
|
||||
}
|
||||
Moldf[r] = Mf[r];
|
||||
|
|
@ -190,6 +200,9 @@ void main() {
|
|||
// Compute sum across row of P
|
||||
rowsumf[r] = 0.0;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowsumf[r] += Pf[r][c];
|
||||
}
|
||||
|
||||
|
|
@ -203,6 +216,9 @@ void main() {
|
|||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
|
|||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
const uint32_t HSV_pad = (HSV + 15) & ~15;
|
||||
|
||||
const bool KV_bounds_check = Clamp != 0;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
|
@ -65,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
|||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
uint shift = (iqs & 0x10) >> 2;
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
} else {
|
||||
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -152,14 +152,17 @@ void main() {
|
|||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc && d < HSK / 4) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
#else
|
||||
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
}
|
||||
|
|
@ -202,7 +205,9 @@ void main() {
|
|||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
|
@ -210,8 +215,11 @@ void main() {
|
|||
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
||||
float rowmaxf = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
||||
}
|
||||
float Moldf = Mf[r];
|
||||
|
|
@ -233,6 +241,9 @@ void main() {
|
|||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
float Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@
|
|||
|
||||
#include "rte.comp"
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
BDA_STORAGE_T dst_addr;
|
||||
uint batch_offset; uint offset_delta;
|
||||
uint IC;
|
||||
uint IW; uint IH;
|
||||
|
|
@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
|
|||
int d0; int d1;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
const uint NUM_ITER = 512 / BLOCK_SIZE;
|
||||
|
|
@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if BDA
|
||||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint gidx = gl_GlobalInvocationID.x;
|
||||
|
||||
|
|
@ -38,7 +43,7 @@ void main() {
|
|||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||
|
||||
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
||||
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
|
||||
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
|
||||
const int oh_s1 = int(oh) * p.s1;
|
||||
const uint ksize = p.OW * p.KH;
|
||||
|
||||
|
|
@ -50,7 +55,7 @@ void main() {
|
|||
uint current_ix = rem % p.OW;
|
||||
|
||||
A_TYPE values[NUM_ITER];
|
||||
uint offset_dst[NUM_ITER];
|
||||
BDA_OFFSET_T offset_dst[NUM_ITER];
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
values[idx] = A_TYPE(0);
|
||||
}
|
||||
|
|
@ -66,7 +71,7 @@ void main() {
|
|||
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
|
||||
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
|
||||
|
||||
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
|
||||
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
|
||||
|
||||
if ((iih < p.IH) && (iiw < p.IW)) {
|
||||
values[idx] = data_a[src_base + iih * p.IW + iiw];
|
||||
|
|
@ -89,7 +94,11 @@ void main() {
|
|||
continue;
|
||||
}
|
||||
|
||||
#if BDA
|
||||
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
|
||||
dst_addr.d = D_TYPE(values[idx]);
|
||||
#else
|
||||
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,8 +6,11 @@
|
|||
|
||||
#include "rte.comp"
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
BDA_STORAGE_T dst_addr;
|
||||
uint32_t nb10;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
|
|
@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
|
|||
uint32_t misalign_offsets;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||
|
||||
|
|
@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if BDA
|
||||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
const uint32_t i = gl_GlobalInvocationID.x;
|
||||
|
||||
|
|
@ -100,13 +105,22 @@ void main() {
|
|||
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
|
||||
const uint32_t iid = iod * s2 + ikd * d2 - p2;
|
||||
|
||||
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
|
||||
|
||||
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
|
||||
#if BDA
|
||||
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
|
||||
if (iih >= IH || iiw >= IW || iid >= ID) {
|
||||
dst_addr.d = D_TYPE(0.0f);
|
||||
} else {
|
||||
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
|
||||
}
|
||||
#else
|
||||
if (iih >= IH || iiw >= IW || iid >= ID) {
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
|
||||
} else {
|
||||
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
|
||||
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -265,7 +265,6 @@ void main() {
|
|||
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
|
||||
|
||||
#if QUANT_K > 1
|
||||
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
|
||||
|
|
@ -281,6 +280,8 @@ void main() {
|
|||
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
|
||||
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
|
||||
|
||||
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
|
||||
|
||||
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
|
||||
|
||||
#if !defined(MUL_MAT_ID)
|
||||
|
|
|
|||
|
|
@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
|
|||
return uintBitsToFloat(bits);
|
||||
}
|
||||
|
||||
#if BDA
|
||||
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable
|
||||
|
||||
#define BDA_STORAGE_T uint64_t
|
||||
#define BDA_OFFSET_T uint64_t
|
||||
|
||||
#else
|
||||
|
||||
#define BDA_STORAGE_T uvec2
|
||||
#define BDA_OFFSET_T uint
|
||||
|
||||
#endif
|
||||
|
||||
#endif // !defined(GGML_TYPES_COMP)
|
||||
|
|
|
|||
|
|
@ -789,13 +789,15 @@ void process_shaders() {
|
|||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||
|
||||
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
|
||||
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
|
||||
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
|
||||
for (std::string dim_str : {"", "_3d"}) {
|
||||
for (bool bda : {false, true}) {
|
||||
std::string bda_str = bda ? "_bda" : "";
|
||||
std::string bda_def = bda ? "1" : "0";
|
||||
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
|
||||
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
|
|
|
|||
|
|
@ -3703,6 +3703,7 @@ struct ggml_tensor * ggml_set_rows(
|
|||
result->op = GGML_OP_SET_ROWS;
|
||||
result->src[0] = b;
|
||||
result->src[1] = c;
|
||||
result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -708,6 +708,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
embd.push_back(id);
|
||||
|
||||
if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
|
||||
assistant_ss << common_token_to_piece(ctx, id, false);
|
||||
}
|
||||
|
||||
// echo this to console
|
||||
input_echo = true;
|
||||
|
||||
|
|
@ -825,11 +829,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
// if current token is not EOG, we add it to current assistant message
|
||||
if (params.conversation_mode && !waiting_for_first_input) {
|
||||
const auto id = common_sampler_last(smpl);
|
||||
assistant_ss << common_token_to_piece(ctx, id, false);
|
||||
|
||||
if (!prompt.empty()) {
|
||||
prompt.clear();
|
||||
is_interacting = false;
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -39,6 +39,7 @@
|
|||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.225 0 0);
|
||||
--code-foreground: oklch(0.875 0 0);
|
||||
--layer-popover: 1000000;
|
||||
}
|
||||
|
||||
.dark {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte';
|
||||
|
||||
interface Props {
|
||||
message: DatabaseMessage;
|
||||
role: 'user' | 'assistant';
|
||||
justify: 'start' | 'end';
|
||||
actionsPosition: 'left' | 'right';
|
||||
|
|
@ -29,7 +28,6 @@
|
|||
actionsPosition,
|
||||
deletionInfo,
|
||||
justify,
|
||||
message,
|
||||
onCopy,
|
||||
onEdit,
|
||||
onConfirmDelete,
|
||||
|
|
@ -49,26 +47,17 @@
|
|||
</script>
|
||||
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-{justify}">
|
||||
<div
|
||||
class="hidden items-center text-xs text-muted-foreground transition-opacity md:flex md:group-hover:opacity-0"
|
||||
>
|
||||
{new Date(message.timestamp).toLocaleTimeString(undefined, {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
})}
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="absolute top-0 {actionsPosition === 'left'
|
||||
? 'left-0'
|
||||
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity md:opacity-0 md:group-hover:opacity-100"
|
||||
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity"
|
||||
>
|
||||
{#if siblingInfo && siblingInfo.totalSiblings > 1}
|
||||
<ChatMessageBranchingControls {siblingInfo} {onNavigateToSibling} />
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-100 transition-all duration-150 md:pointer-events-none md:opacity-0 md:group-hover:pointer-events-auto md:group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-100 transition-all duration-150"
|
||||
>
|
||||
<ActionButton icon={Copy} tooltip="Copy" onclick={onCopy} />
|
||||
|
||||
|
|
|
|||
|
|
@ -138,7 +138,6 @@
|
|||
|
||||
{#if message.timestamp && !isEditing}
|
||||
<ChatMessageActions
|
||||
{message}
|
||||
role="assistant"
|
||||
justify="start"
|
||||
actionsPosition="left"
|
||||
|
|
|
|||
|
|
@ -135,7 +135,6 @@
|
|||
actionsPosition="right"
|
||||
{deletionInfo}
|
||||
justify="end"
|
||||
{message}
|
||||
{onConfirmDelete}
|
||||
{onCopy}
|
||||
{onDelete}
|
||||
|
|
|
|||
|
|
@ -362,7 +362,8 @@
|
|||
|
||||
<Dialog.Root {open} onOpenChange={handleClose}>
|
||||
<Dialog.Content
|
||||
class="z-999999 flex h-[100vh] flex-col gap-0 rounded-none p-0 md:h-[64vh] md:rounded-lg"
|
||||
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
|
||||
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
|
||||
style="max-width: 48rem;"
|
||||
>
|
||||
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
|
||||
|
|
@ -441,7 +442,7 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<ScrollArea class="max-h-[calc(100vh-13.5rem)] flex-1">
|
||||
<ScrollArea class="max-h-[calc(100dvh-13.5rem)] flex-1 md:max-h-[calc(100vh-13.5rem)]">
|
||||
<div class="space-y-6 p-4 md:p-6">
|
||||
<div>
|
||||
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
||||
|
|
@ -456,7 +457,6 @@
|
|||
{localConfig}
|
||||
onConfigChange={handleConfigChange}
|
||||
onThemeChange={handleThemeChange}
|
||||
isMobile={false}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -13,10 +13,9 @@
|
|||
localConfig: SettingsConfigType;
|
||||
onConfigChange: (key: string, value: string | boolean) => void;
|
||||
onThemeChange?: (theme: string) => void;
|
||||
isMobile?: boolean;
|
||||
}
|
||||
|
||||
let { fields, localConfig, onConfigChange, onThemeChange, isMobile = false }: Props = $props();
|
||||
let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#each fields as field (field.key)}
|
||||
|
|
@ -28,10 +27,10 @@
|
|||
|
||||
<Input
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
value={String(localConfig[field.key] ?? '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class={isMobile ? 'w-full' : 'max-w-md'}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||
class="w-full md:max-w-md"
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
|
|
@ -45,10 +44,10 @@
|
|||
|
||||
<Textarea
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
value={String(localConfig[field.key] ?? '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class={isMobile ? 'min-h-[100px] w-full' : 'min-h-[100px] max-w-2xl'}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
|
||||
class="min-h-[100px] w-full md:max-w-2xl"
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
|
|
@ -76,7 +75,7 @@
|
|||
}
|
||||
}}
|
||||
>
|
||||
<Select.Trigger class={isMobile ? 'w-full' : 'max-w-md'}>
|
||||
<Select.Trigger class="w-full md:w-auto md:max-w-md">
|
||||
<div class="flex items-center gap-2">
|
||||
{#if selectedOption?.icon}
|
||||
{@const IconComponent = selectedOption.icon}
|
||||
|
|
@ -109,6 +108,7 @@
|
|||
{/if}
|
||||
{:else if field.type === 'checkbox'}
|
||||
{@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()}
|
||||
|
||||
<div class="flex items-start space-x-3">
|
||||
<Checkbox
|
||||
id={field.key}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
|
||||
interface Props {
|
||||
onReset?: () => void;
|
||||
|
|
@ -8,8 +9,15 @@
|
|||
|
||||
let { onReset, onSave }: Props = $props();
|
||||
|
||||
function handleReset() {
|
||||
let showResetDialog = $state(false);
|
||||
|
||||
function handleResetClick() {
|
||||
showResetDialog = true;
|
||||
}
|
||||
|
||||
function handleConfirmReset() {
|
||||
onReset?.();
|
||||
showResetDialog = false;
|
||||
}
|
||||
|
||||
function handleSave() {
|
||||
|
|
@ -18,7 +26,23 @@
|
|||
</script>
|
||||
|
||||
<div class="flex justify-between border-t border-border/30 p-6">
|
||||
<Button variant="outline" onclick={handleReset}>Reset to default</Button>
|
||||
<Button variant="outline" onclick={handleResetClick}>Reset to default</Button>
|
||||
|
||||
<Button onclick={handleSave}>Save settings</Button>
|
||||
</div>
|
||||
|
||||
<AlertDialog.Root bind:open={showResetDialog}>
|
||||
<AlertDialog.Content>
|
||||
<AlertDialog.Header>
|
||||
<AlertDialog.Title>Reset Settings to Default</AlertDialog.Title>
|
||||
<AlertDialog.Description>
|
||||
Are you sure you want to reset all settings to their default values? This action cannot be
|
||||
undone and will permanently remove all your custom configurations.
|
||||
</AlertDialog.Description>
|
||||
</AlertDialog.Header>
|
||||
<AlertDialog.Footer>
|
||||
<AlertDialog.Cancel>Cancel</AlertDialog.Cancel>
|
||||
<AlertDialog.Action onclick={handleConfirmReset}>Reset to Default</AlertDialog.Action>
|
||||
</AlertDialog.Footer>
|
||||
</AlertDialog.Content>
|
||||
</AlertDialog.Root>
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@
|
|||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each filteredConversations as conversation (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1" onclick={handleMobileSidebarItemClick}>
|
||||
<Sidebar.MenuItem class="mb-1">
|
||||
<ChatSidebarConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
|
|
@ -95,6 +95,7 @@
|
|||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode
|
||||
}}
|
||||
{handleMobileSidebarItemClick}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={editConversation}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
interface Props {
|
||||
isActive?: boolean;
|
||||
conversation: DatabaseConversation;
|
||||
handleMobileSidebarItemClick?: () => void;
|
||||
onDelete?: (id: string) => void;
|
||||
onEdit?: (id: string, name: string) => void;
|
||||
onSelect?: (id: string) => void;
|
||||
|
|
@ -16,6 +17,7 @@
|
|||
|
||||
let {
|
||||
conversation,
|
||||
handleMobileSidebarItemClick,
|
||||
onDelete,
|
||||
onEdit,
|
||||
onSelect,
|
||||
|
|
@ -47,6 +49,7 @@
|
|||
|
||||
function handleConfirmEdit() {
|
||||
if (!editedName.trim()) return;
|
||||
showEditDialog = false;
|
||||
onEdit?.(conversation.id, editedName);
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +88,12 @@
|
|||
: ''}"
|
||||
onclick={handleSelect}
|
||||
>
|
||||
<div class="text flex min-w-0 flex-1 items-center space-x-3">
|
||||
<!-- svelte-ignore a11y_click_events_have_key_events -->
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="text flex min-w-0 flex-1 items-center space-x-3"
|
||||
onclick={handleMobileSidebarItemClick}
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<p class="truncate text-sm font-medium">{conversation.name}</p>
|
||||
|
||||
|
|
@ -178,5 +186,10 @@
|
|||
&:is(:hover) :global([data-slot='dropdown-menu-trigger']) {
|
||||
opacity: 1;
|
||||
}
|
||||
@media (max-width: 768px) {
|
||||
:global([data-slot='dropdown-menu-trigger']) {
|
||||
opacity: 1 !important;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@
|
|||
<DropdownMenu.Root bind:open>
|
||||
<DropdownMenu.Trigger
|
||||
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{#if triggerTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
|
|
@ -53,7 +54,7 @@
|
|||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content {align} class="z-999 w-48">
|
||||
<DropdownMenu.Content {align} class="z-[999999] w-48">
|
||||
{#each actions as action, index (action.label)}
|
||||
{#if action.separator && index > 0}
|
||||
<DropdownMenu.Separator />
|
||||
|
|
|
|||
|
|
@ -19,7 +19,15 @@
|
|||
bind:ref
|
||||
data-slot="alert-dialog-content"
|
||||
class={cn(
|
||||
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
|
||||
'fixed z-[999999] grid w-full gap-4 border bg-background p-6 shadow-lg duration-200',
|
||||
// Mobile: Bottom sheet behavior
|
||||
'right-0 bottom-0 left-0 max-h-[100dvh] translate-x-0 translate-y-0 overflow-y-auto rounded-t-lg',
|
||||
'data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:slide-out-to-bottom-full',
|
||||
'data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:slide-in-from-bottom-full',
|
||||
// Desktop: Centered dialog behavior
|
||||
'sm:top-[50%] sm:right-auto sm:bottom-auto sm:left-[50%] sm:max-h-[100vh] sm:max-w-lg sm:translate-x-[-50%] sm:translate-y-[-50%] sm:rounded-lg',
|
||||
'sm:data-[state=closed]:slide-out-to-bottom-0 sm:data-[state=closed]:zoom-out-95',
|
||||
'sm:data-[state=open]:slide-in-from-bottom-0 sm:data-[state=open]:zoom-in-95',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@
|
|||
<div
|
||||
bind:this={ref}
|
||||
data-slot="alert-dialog-footer"
|
||||
class={cn('flex flex-col-reverse gap-2 sm:flex-row sm:justify-end', className)}
|
||||
class={cn(
|
||||
'mt-6 flex flex-row gap-2 sm:mt-0 sm:justify-end [&>*]:flex-1 sm:[&>*]:flex-none',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
bind:ref
|
||||
data-slot="dialog-content"
|
||||
class={cn(
|
||||
'fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg',
|
||||
`fixed top-[50%] left-[50%] z-50 grid max-h-[100dvh] w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-lg border border-border/30 bg-background p-6 shadow-lg duration-200 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95 sm:max-w-lg md:max-h-[100vh]`,
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { onDestroy, onMount } from 'svelte';
|
||||
import { Select as SelectPrimitive } from 'bits-ui';
|
||||
import SelectScrollUpButton from './select-scroll-up-button.svelte';
|
||||
import SelectScrollDownButton from './select-scroll-down-button.svelte';
|
||||
|
|
@ -14,6 +15,76 @@
|
|||
}: WithoutChild<SelectPrimitive.ContentProps> & {
|
||||
portalProps?: SelectPrimitive.PortalProps;
|
||||
} = $props();
|
||||
|
||||
let cleanupInternalListeners: (() => void) | undefined;
|
||||
|
||||
onMount(() => {
|
||||
const listenerOptions: AddEventListenerOptions = { passive: false };
|
||||
|
||||
const blockOutsideWheel = (event: WheelEvent) => {
|
||||
if (!ref) {
|
||||
return;
|
||||
}
|
||||
|
||||
const target = event.target as Node | null;
|
||||
|
||||
if (!target || !ref.contains(target)) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
};
|
||||
|
||||
const blockOutsideTouchMove = (event: TouchEvent) => {
|
||||
if (!ref) {
|
||||
return;
|
||||
}
|
||||
|
||||
const target = event.target as Node | null;
|
||||
|
||||
if (!target || !ref.contains(target)) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener('wheel', blockOutsideWheel, listenerOptions);
|
||||
document.addEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('wheel', blockOutsideWheel, listenerOptions);
|
||||
document.removeEventListener('touchmove', blockOutsideTouchMove, listenerOptions);
|
||||
};
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
const element = ref;
|
||||
|
||||
cleanupInternalListeners?.();
|
||||
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
|
||||
const stopWheelPropagation = (event: WheelEvent) => {
|
||||
event.stopPropagation();
|
||||
};
|
||||
|
||||
const stopTouchPropagation = (event: TouchEvent) => {
|
||||
event.stopPropagation();
|
||||
};
|
||||
|
||||
element.addEventListener('wheel', stopWheelPropagation);
|
||||
element.addEventListener('touchmove', stopTouchPropagation);
|
||||
|
||||
cleanupInternalListeners = () => {
|
||||
element.removeEventListener('wheel', stopWheelPropagation);
|
||||
element.removeEventListener('touchmove', stopTouchPropagation);
|
||||
};
|
||||
});
|
||||
|
||||
onDestroy(() => {
|
||||
cleanupInternalListeners?.();
|
||||
});
|
||||
</script>
|
||||
|
||||
<SelectPrimitive.Portal {...portalProps}>
|
||||
|
|
@ -22,7 +93,7 @@
|
|||
{sideOffset}
|
||||
data-slot="select-content"
|
||||
class={cn(
|
||||
'relative z-50 max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
|
||||
'relative z-[var(--layer-popover,1000000)] max-h-(--bits-select-content-available-height) min-w-[8rem] origin-(--bits-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md data-[side=bottom]:translate-y-1 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:-translate-x-1 data-[side=left]:slide-in-from-right-2 data-[side=right]:translate-x-1 data-[side=right]:slide-in-from-left-2 data-[side=top]:-translate-y-1 data-[side=top]:slide-in-from-bottom-2 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[state=open]:animate-in data-[state=open]:fade-in-0 data-[state=open]:zoom-in-95',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue