Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-12 18:09:42 +00:00)

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/full-cuda.Dockerfile
#	.devops/llama-cli-cann.Dockerfile
#	.devops/llama-cli-cuda.Dockerfile
#	.devops/llama-cli-intel.Dockerfile
#	.devops/llama-cli-musa.Dockerfile
#	.devops/llama-cli-vulkan.Dockerfile
#	.devops/llama-server-cuda.Dockerfile
#	.devops/llama-server-intel.Dockerfile
#	.devops/llama-server-musa.Dockerfile
#	.devops/llama-server-vulkan.Dockerfile
#	.gitignore
#	CMakeLists.txt
#	Makefile
#	cmake/llama-config.cmake.in
#	docs/backend/SYCL.md
#	docs/build.md
#	examples/llama-bench/llama-bench.cpp
#	flake.lock
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-backend.cpp
#	ggml/src/ggml-blas/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-cpu/ggml-cpu.c
#	ggml/src/ggml-cuda/CMakeLists.txt
#	ggml/src/ggml-hip/CMakeLists.txt
#	ggml/src/ggml-metal/CMakeLists.txt
#	ggml/src/ggml-musa/CMakeLists.txt
#	ggml/src/ggml-sycl/CMakeLists.txt
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp

Commit 091a432cf6: 38 changed files with 167475 additions and 136736 deletions.
161	.clang-format	Normal file

@@ -0,0 +1,161 @@
---
Language: Cpp
AlignAfterOpenBracket: Align
AlignArrayOfStructures: Left
AlignConsecutiveAssignments: AcrossComments
AlignConsecutiveBitFields: AcrossComments
AlignConsecutiveDeclarations: AcrossComments
AlignConsecutiveMacros: AcrossComments
# AlignConsecutiveShortCaseStatements: AcrossComments
AlignEscapedNewlines: Left # LeftWithLastLine
AlignOperands: Align
AlignTrailingComments:
  Kind: Always
  OverEmptyLines: 1
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: false
# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: true
BinPackParameters: true # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
BraceWrapping:
  AfterCaseLabel:        true
  AfterClass:            false
  AfterControlStatement: false
  AfterEnum:             false
  AfterFunction:         false
  AfterNamespace:        false
  AfterObjCDeclaration:  false
  AfterStruct:           false
  AfterUnion:            false
  AfterExternBlock:      false
  BeforeCatch:           false
  BeforeElse:            false
  BeforeLambdaBody:      false
  BeforeWhile:           false
  IndentBraces:          false
  SplitEmptyFunction:    false
  SplitEmptyRecord:      false
  SplitEmptyNamespace:   false
# BreakAdjacentStringLiterals: true
BreakAfterAttributes: Never
BreakBeforeBinaryOperators: None
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: false
# BreakBinaryOperations: Never
BreakConstructorInitializers: AfterColon
# BreakFunctionDefinitionParameters: false
BreakInheritanceList: AfterComma
BreakStringLiterals: true
# BreakTemplateDeclarations: Yes
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
EmptyLineBeforeAccessModifier: Leave
EmptyLineAfterAccessModifier: Never
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
IncludeBlocks: Regroup
IncludeCategories:
  - Regex: '^<.*\.h>'
    Priority: 1
    SortPriority: 0
  - Regex: '^<.*'
    Priority: 2
    SortPriority: 0
  - Regex: '.*'
    Priority: 3
    SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: true
IndentCaseLabels: true
IndentExternBlock: NoIndent
IndentGotoLabels: false
IndentPPDirectives: AfterHash
IndentWidth: 4
IndentWrappedFunctionNames: false
InsertBraces: true # NOTE: may lead to incorrect formatting
InsertNewlineAtEOF: true
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
LambdaBodyIndentation: Signature
LineEnding: LF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PPIndentWidth: -1
PackConstructorInitializers: CurrentLine
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Middle
QualifierAlignment: Left
#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
RawStringFormats:
  - Language: Cpp
    Delimiters:
      - cc
      - CC
      - cpp
      - Cpp
      - CPP
      - 'c++'
      - 'C++'
    CanonicalDelimiter: ''
ReferenceAlignment: Middle
ReflowComments: false # IndentOnly
SeparateDefinitionBlocks: Always
SortIncludes: CaseInsensitive
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: Never
SpacesInContainerLiterals: true
SpacesInLineCommentPrefix:
  Minimum: 1
  Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpaceBeforeSquareBrackets: false
Standard: c++17
TabWidth: 4
UseTab: Never
WhitespaceSensitiveMacros: ['STRINGIZE']
...
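Aside (not part of the commit): a minimal before/after sketch of two of the options above, assuming clang-format 16 or newer is run with this config via `clang-format -style=file`. `InsertBraces: true` wraps unbraced single-statement bodies (the inline note in the config warns this may misformat corner cases), and `AlignConsecutiveAssignments: AcrossComments` lines up the assignment operators:

// before
int x = 1;
int longer_name = 2;
if (x) longer_name = 3;

// after
int x           = 1;
int longer_name = 2;
if (x) {
    longer_name = 3;
}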
@@ -1,26 +0,0 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc3.1.0
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}

FROM ${BASE_MUSA_DEV_CONTAINER} AS build

RUN apt-get update && \
    apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1

COPY requirements.txt requirements.txt
COPY requirements     requirements

RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.txt

WORKDIR /app

COPY . .

RUN cmake -B build -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
    cmake --build build --config Release -j$(nproc) && \
    cp build/bin/* .

ENTRYPOINT ["/app/.devops/tools.sh"]
@@ -877,6 +877,12 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
+        LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)\n", __func__);
+        llama_free_model(model);
+        return iparams;
+    }
+
     if (!params.control_vectors.empty()) {
         if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
         if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
@@ -3040,6 +3040,11 @@ class OlmoModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("Olmo1124ForCausalLM")
+class Olmo1124Model(Model):
+    model_arch = gguf.MODEL_ARCH.OLMO_1124
+
+
 @Model.register("OlmoeForCausalLM")
 class OlmoeModel(Model):
     model_arch = gguf.MODEL_ARCH.OLMOE
@@ -252,6 +252,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {

@@ -266,6 +267,7 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {

@@ -689,7 +691,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
 }
 
 static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
-    ggml_backend_buffer_t buffer = tensor->buffer;
+    ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }

@@ -724,8 +726,6 @@ static bool backend_prealloc_warn = false;
 
 // returns the backend that should be used for the node based on the current locations
 static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
-    // TODO: use supports_op to check if the backend supports the op
-
     // assign pre-allocated nodes to their backend
     int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {

@@ -747,7 +747,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if(!backend_prealloc_warn)
         {
             backend_prealloc_warn = true;
-            printf("\nCaution: pre-allocated tensor in a backend that cannot run the operation\n");
+            printf("\nCaution: pre-allocated tensor (%s) in a backend that cannot run the operation\n", tensor->name);
         }
     }
 
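All of the `view_src` changes in this file apply one rule: a view tensor borrows memory from its source tensor, so the owning buffer must be resolved through `view_src` when it is set. A standalone sketch of that lookup (hypothetical helper name, for illustration only; not a function in the codebase):

// Hypothetical helper: resolve the buffer that actually owns a tensor's
// storage, following view_src for view tensors whose own ->buffer may be NULL.
static ggml_backend_buffer_t owning_buffer(const struct ggml_tensor * t) {
    return t->view_src ? t->view_src->buffer : t->buffer;
}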
@@ -2372,15 +2372,15 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
     // figure out which node we're on
     uint current_cpu;
     int getcpu_ret = 0;
-    // #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
+    //#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)
     // getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
-    // #else
+    //#else
     // // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
-    // # if !defined(SYS_getcpu) && defined(SYS_get_cpu)
-    // # define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
-    // # endif
+    //# if !defined(SYS_getcpu) && defined(SYS_get_cpu)
+    //# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
+    //# endif
     // getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
-    // #endif
+    //#endif
     // koboldcpp fix: we don't use numa and this thing breaks runpod
 
     if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
@@ -1,699 +0,0 @@
#include "dmmv.cuh"
#include "dequantize.cuh"
#include "convert.cuh"

#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif

static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    if (row > nrows) return;

    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q2_K * x = (const block_q2_K *)vx + ib0;

    float tmp = 0; // partial sum for thread in warp

    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

    const int step = 16/K_QUANTS_PER_ITERATION;

    const int im = tid/step;                 // 0 or 1. 0 computes 0..., 1 computes 128...
    const int in = tid - step*im;            // 0...15 or 0...7

    const int l0 = K_QUANTS_PER_ITERATION*in;  // 0...15 or 0...14 in steps of 2
    const int q_offset = 32*im + l0;
    const int s_offset = 8*im;
    const int y_offset = 128*im + l0;

    uint32_t aux[4];
    const uint8_t * d = (const uint8_t *)aux;
    const uint8_t * m = (const uint8_t *)(aux + 2);

    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

        const float   * y = yy + i * QK_K + y_offset;
        const uint8_t * q = x[i].qs + q_offset;

        const float dall = __low2half(x[i].dm);
        const float dmin = __high2half(x[i].dm);

        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
        aux[0] = a[0] & 0x0f0f0f0f;
        aux[1] = a[1] & 0x0f0f0f0f;
        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

        float sum1 = 0, sum2 = 0;
        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];

        }
        tmp += dall * sum1 - dmin * sum2;

    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}
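The `aux[]` masking above unpacks the Q2_K scales: each byte of `x[i].scales` packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble (consumed as `d` and `m` in the kernel), and a single AND over a 32-bit word extracts four of them at once. A self-contained host-side illustration of the same bit trick (example value chosen arbitrarily):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t word = 0xA1B2C3D4u;               // four packed bytes
    const uint32_t lo   = word & 0x0f0f0f0f;         // low nibbles:  0x01020304
    const uint32_t hi   = (word >> 4) & 0x0f0f0f0f;  // high nibbles: 0x0A0B0C0D
    printf("lo=%08x hi=%08x\n", lo, hi);
    return 0;
}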
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    if (row > nrows) return;

    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q3_K * x = (const block_q3_K *)vx + ib0;

    float tmp = 0; // partial sum for thread in warp

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
    const int step = 16/K_QUANTS_PER_ITERATION;
    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
    const int in = tid - step*im;                        // 0....15 or 0...7

    const uint8_t m = 1 << (4*im);

    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
    const int q_offset =  32*im + l0;
    const int y_offset = 128*im + l0;

    uint16_t utmp[4];
    const int8_t * s = (const int8_t *)utmp;

    const uint16_t s_shift = 4*im;

    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

        const float   * y = yy + i * QK_K + y_offset;
        const uint8_t * q = x[i].qs + q_offset;
        const uint8_t * h = x[i].hmask + l0;

        const uint16_t * a = (const uint16_t *)x[i].scales;
        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);

        const float d = x[i].d;

        float sum = 0;
        for (int l = 0; l < n; ++l) {
            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
        }
        tmp += d * sum;

    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    if (row > nrows) return;
    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q4_K * x = (const block_q4_K *)vx + ib0;

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4

    const int il  = tid/step;                            // 0...3
    const int ir  = tid - step*il;                       // 0...7 or 0...3
    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4

    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int in = il%2;

    const int l0 = n*(2*ir + in);
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    uint16_t aux[4];
    const uint8_t * sc = (const uint8_t *)aux;

#if K_QUANTS_PER_ITERATION == 2
    uint32_t q32[4];
    const uint8_t * q4 = (const uint8_t *)q32;
#else
    uint16_t q16[4];
    const uint8_t * q4 = (const uint8_t *)q16;
#endif

    float tmp = 0; // partial sum for thread in warp

    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

        const float * y1 = yy + i*QK_K + y_offset;
        const float * y2 = y1 + 128;

        const float dall = __low2half(x[i].dm);
        const float dmin = __high2half(x[i].dm);

        const uint16_t * a = (const uint16_t *)x[i].scales;
        aux[0] = a[im+0] & kmask1;
        aux[1] = a[im+2] & kmask1;
        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

#if K_QUANTS_PER_ITERATION == 2
        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
        const uint32_t * q2 = q1 + 16;

        q32[0] = q1[0] & 0x0f0f0f0f;
        q32[1] = q1[0] & 0xf0f0f0f0;
        q32[2] = q2[0] & 0x0f0f0f0f;
        q32[3] = q2[0] & 0xf0f0f0f0;

        float4 s = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        for (int l = 0; l < 4; ++l) {
            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
        }
        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
        const uint16_t * q2 = q1 + 32;

        q16[0] = q1[0] & 0x0f0f;
        q16[1] = q1[0] & 0xf0f0;
        q16[2] = q2[0] & 0x0f0f;
        q16[3] = q2[0] & 0xf0f0;

        float4 s = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        for (int l = 0; l < 2; ++l) {
            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
        }
        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif

    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
        dst[row] = tmp;
    }
}

static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {

    const int row = blockIdx.x;
    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q5_K * x = (const block_q5_K *)vx + ib0;

    float tmp = 0; // partial sum for thread in warp

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int tid = threadIdx.x/2;  // 0...15
    const int ix  = threadIdx.x%2;

    const int il  = tid/4;      // 0...3
    const int ir  = tid - 4*il; // 0...3
    const int n   = 2;

    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const int in = il%2;

    const int l0 = n*(2*ir + in);
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    const uint8_t hm1 = 1 << (2*im);
    const uint8_t hm2 = hm1 << 4;

    uint16_t aux[4];
    const uint8_t * sc = (const uint8_t *)aux;

    uint16_t q16[8];
    const uint8_t * q4 = (const uint8_t *)q16;

    for (int i = ix; i < num_blocks_per_row; i += 2) {

        const uint8_t * ql1 = x[i].qs + q_offset;
        const uint8_t * qh  = x[i].qh + l0;
        const float   * y1  = yy + i*QK_K + y_offset;
        const float   * y2  = y1 + 128;

        const float dall = __low2half(x[i].dm);
        const float dmin = __high2half(x[i].dm);

        const uint16_t * a = (const uint16_t *)x[i].scales;
        aux[0] = a[im+0] & kmask1;
        aux[1] = a[im+2] & kmask1;
        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

        float4 sum = {0.f, 0.f, 0.f, 0.f};
        float smin = 0;
        const uint16_t * q1 = (const uint16_t *)ql1;
        const uint16_t * q2 = q1 + 32;
        q16[0] = q1[0] & 0x0f0f;
        q16[1] = q1[8] & 0x0f0f;
        q16[2] = (q1[0] >> 4) & 0x0f0f;
        q16[3] = (q1[8] >> 4) & 0x0f0f;
        q16[4] = q2[0] & 0x0f0f;
        q16[5] = q2[8] & 0x0f0f;
        q16[6] = (q2[0] >> 4) & 0x0f0f;
        q16[7] = (q2[8] >> 4) & 0x0f0f;
        for (int l = 0; l < n; ++l) {
            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
        }
        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    if (row > nrows) return;

    const int num_blocks_per_row = ncols / QK_K;
    const int ib0 = row*num_blocks_per_row;

    const block_q6_K * x = (const block_q6_K *)vx + ib0;

    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1

    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8

    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
    const int in = tid - step*im;                        // 0...15 or 0...7

#if K_QUANTS_PER_ITERATION == 1
    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
    const int is = 0;
#else
    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
    const int is = in / 4;
#endif
    const int ql_offset = 64*im + l0;
    const int qh_offset = 32*im + l0;
    const int s_offset  =  8*im + is;
    const int y_offset  = 128*im + l0;

    float tmp = 0; // partial sum for thread in warp

    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

        const float   * y  = yy + i * QK_K + y_offset;
        const uint8_t * ql = x[i].ql + ql_offset;
        const uint8_t * qh = x[i].qh + qh_offset;
        const int8_t  * s  = x[i].scales + s_offset;

        const float d = x[i].d;

#if K_QUANTS_PER_ITERATION == 1
        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
        tmp += sum;
#else
        float sum = 0;
        for (int l = 0; l < 4; ++l) {
            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
        }
        tmp += sum;
#endif

    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
        dst[row] = tmp;
    }
}

static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    const half * x = (const half *) vx;
    // load 2 halfs into register in a single instruction
    const half2 x_reg = *((half2 *) &(x[ib + iqs]));
    // automatic half -> float type cast if dfloat == float
    v.x = __low2float(x_reg);
    v.y = __high2float(x_reg);
}

static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
    return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
        type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
        type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
        type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
        type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
        type == GGML_TYPE_F16 ? convert_f16 :
        nullptr;
}

template <ggml_type type>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
    constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
    constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
    constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);

    const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;

    if (row >= nrows) {
        return;
    }

    const int tid = threadIdx.x;

    const int iter_stride = 2*GGML_CUDA_DMMV_X;
    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
    const int y_offset = qr == 1 ? 1 : qk/2;

// partial sum for each thread
#ifdef GGML_CUDA_F16
    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
    float tmp = 0.0f;
#endif // GGML_CUDA_F16

    for (int i = 0; i < ncols; i += iter_stride) {
        const int col = i + vals_per_iter*tid;
        const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
        const int iqs = (col%qk)/qr; // x quant index
        const int iybs = col - col%qk; // y block start index

// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
        for (int j = 0; j < vals_per_iter; j += 2) {
            // process 2 vals per j iter

            // dequantize
            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
            dfloat2 v;
            dequantize_kernel(vx, ib, iqs + j/qr, v);

            // matrix multiplication
            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_F16
            if ( y_offset == 1 ) {
                // load 2 dfloats into register in a single instruction
                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
                tmp += __hmul2(v, y_reg);
            }
            else {
                tmp += __hmul2(v, {
                    y[iybs + iqs + j/qr + 0],
                    y[iybs + iqs + j/qr + y_offset]
                });
            }
#else
            if ( y_offset == 1 ) {
                // load 2 dfloats into register in a single instruction
                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
                tmp += v.x * y_reg.x;
                tmp += v.y * y_reg.y;
            }
            else {
                tmp += v.x * y[iybs + iqs + j/qr + 0];
                tmp += v.y * y[iybs + iqs + j/qr + y_offset];
            }
#endif // GGML_CUDA_F16
        }
    }

    // sum up partial sums and write back result
    tmp = warp_reduce_sum(tmp);

    if (tid == 0) {
#ifdef GGML_CUDA_F16
        dst[row] = tmp.x + tmp.y;
#else
        dst[row] = tmp;
#endif // GGML_CUDA_F16
    }
}
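Every kernel in this file finishes with `warp_reduce_sum` before lane 0 writes the dot product. The real implementation lives in the backend's common header; a typical shuffle-based reduction of this kind looks like the following sketch (assuming 32-lane warps; the name is illustrative, not the actual ggml function):

__device__ float warp_reduce_sum_sketch(float x) {
    // butterfly reduction: after log2(32) = 5 steps every lane of the warp
    // holds the sum of all 32 lanes' inputs
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}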
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
    const int block_num_y = (nrows + ny - 1) / ny;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const dim3 block_dims(32, 1, 1);
    dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(32, ny, 1);
    dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_vec<GGML_TYPE_F16>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

void ggml_cuda_op_dequantize_mul_mat_vec(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {
    GGML_UNUSED(ctx);
    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
    ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
    half * src1_dfloat = nullptr; // dfloat == half

    bool src1_convert_f16 =
        src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
        src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
        src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;

    if (src1_convert_f16) {
        src1_dfloat = src1_dfloat_a.alloc(ne00);
        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
        GGML_ASSERT(to_fp16_cuda != nullptr);
        to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
    }
#else
    const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_F16

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q4_1:
            dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q5_0:
            dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q5_1:
            dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q8_0:
            dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q2_K:
            dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q3_K:
            dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q4_K:
            dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q5_K:
            dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_Q6_K:
            dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
            break;
        case GGML_TYPE_F16:
            convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddq_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}

bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
        src0_type == GGML_TYPE_F16;
}
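All of the launchers above size their grids with the same idiom: `(nrows + ny - 1) / ny` is integer division rounded up, so the grid always has enough blocks to cover every row. A one-line sanity check of the arithmetic:

// ny = 2: 7 rows -> (7+1)/2 = 4 blocks (one row slot unused); 8 rows -> 4 blocks exactly.
static_assert((7 + 2 - 1) / 2 == 4 && (8 + 2 - 1) / 2 == 4, "ceil-div");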
@@ -18,11 +18,11 @@ bool g_mul_mat_q = false;
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
-#include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmv.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -1021,114 +1021,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
 
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
-static __global__ void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / (nchannels_y / nchannels_x);
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi = __half2float(x[ix]);
-
-        const int row_y = col_x;
-
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
-
-        tmp += xi * y[iy];
-    }
-
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
-
-    const half * x = (const half *) vx;
-
-    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
-    const int channel_x = channel / channel_x_divisor;
-
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
-
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
-        const int col_x = col_x0 + threadIdx.x;
-
-        if (col_x >= ncols_x) {
-            break;
-        }
-
-        const int row_y = col_x;
-
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
-
-        const float xi = __half2float(x[ix]);
-
-        tmp += xi * y[iy];
-    }
-
-    // sum up partial sums and write back result
-    tmp = warp_reduce_sum(tmp);
-
-    if (threadIdx.x == 0) {
-        dst[idst] = tmp;
-    }
-}
-
-static void ggml_mul_mat_p021_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
-}
-
-static void ggml_mul_mat_vec_nc_f16_f32_cuda(
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
-
-    const dim3 block_nums(1, nrows_x, nchannels_y);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
-}
-
 static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
@@ -1655,58 +1547,6 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
-
-    const int64_t ne12 = src1->ne[2];
-
-    cudaStream_t main_stream = ctx.stream();
-
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
-
-    const int64_t row_stride_x     = nb01 / sizeof(half);
-    const int64_t channel_stride_x = nb02 / sizeof(half);
-
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
-}
-
 static __global__ void k_compute_batched_ptrs(
     const half * src0_as_f16, const half * src1_as_f16, char * dst,
     const void ** ptrs_src, void ** ptrs_dst,
@@ -1880,21 +1720,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
+    bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
+        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q = ggml_is_quantized(src0->type)
        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    // if mmvq is available it's a better choice than dmmv:
-#ifndef GGML_CUDA_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_CUDA_FORCE_DMMV
-
     bool any_gpus_with_slow_fp16 = false;
+    bool any_gpus_without_fp16_mma = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;

@@ -1908,11 +1744,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             const int cc = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
     }
 
     // debug helpers

@@ -1923,18 +1761,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
-        ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+        // the custom F16 vector kernel can be used over batched cuBLAS GEMM
+        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch without FlashAttention
+        // general KQ + KQV multi-batch without FlashAttention
        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
223	ggml/src/ggml-cuda/mmv.cu	Normal file

@@ -0,0 +1,223 @@
#include "common.cuh"
|
||||
#include "mmv.cuh"
|
||||
|
||||
template <typename type_acc, int block_size>
|
||||
static __global__ void mul_mat_vec(
|
||||
const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
|
||||
y += channel *stride_channel_y;
|
||||
dst += channel *stride_channel_dst;
|
||||
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (tid < WARP_SIZE) {
|
||||
buf_iw[tid] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float sumf;
|
||||
|
||||
if (std::is_same<type_acc, float>::value) {
|
||||
sumf = 0.0f;
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = __half22float2(x2[col2]);
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += tmpx.x * tmpy.x;
|
||||
sumf += tmpx.y * tmpy.y;
|
||||
}
|
||||
} else {
|
||||
#ifdef FP16_AVAILABLE
|
||||
half2 sumh2 = make_half2(0.0f, 0.0f);
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmp = y2[col2];
|
||||
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
|
||||
}
|
||||
|
||||
sumf = __low2float(sumh2) + __high2float(sumh2);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
sumf = warp_reduce_sum(sumf);
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
buf_iw[tid/WARP_SIZE] = sumf;
|
||||
__syncthreads();
|
||||
if (tid > WARP_SIZE) {
|
||||
return;
|
||||
}
|
||||
sumf = buf_iw[tid];
|
||||
sumf = warp_reduce_sum(sumf);
|
||||
}
|
||||
|
||||
if (tid != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[row] = sumf;
|
||||
}
|
||||
|
template <typename type_acc>
static void launch_mul_mat_vec_cuda(
        const half * x, const float * y, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        cudaStream_t stream) {
    GGML_ASSERT(ncols       % 2 == 0);
    GGML_ASSERT(stride_row  % 2 == 0);
    GGML_ASSERT(nchannels_y % nchannels_x == 0);
    const int64_t channel_ratio = nchannels_y / nchannels_x;

    int64_t block_size_best = WARP_SIZE;
    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) {
            niter_best      = niter;
            block_size_best = block_size;
        }
    }

    const int smem = WARP_SIZE*sizeof(float);
    const dim3 block_nums(nrows, 1, nchannels_y);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case  32: {
            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  64: {
            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case  96: {
            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case 128: {
            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case 160: {
            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case 192: {
            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case 224: {
            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        case 256: {
            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}
||||
static void mul_mat_vec_cuda(
|
||||
const half * x, const float * y, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
switch (prec) {
|
||||
case GGML_PREC_DEFAULT: {
|
||||
launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
||||
} break;
|
||||
case GGML_PREC_F32: {
|
||||
launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
||||
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
|
||||
GGML_ASSERT(src1->ne[1] == 1);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
GGML_ASSERT(dst->ne[2] == ne12);
|
||||
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[3] == 1);
|
||||
GGML_ASSERT( dst->ne[3] == 1);
|
||||
|
||||
const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
|
||||
const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
|
||||
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
|
||||
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
|
||||
|
||||
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
GGML_ASSERT(src1_ncols == 1);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
|
||||
// ggml_cuda_op provides single, contiguous matrices
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t nchannels_x = 1;
|
||||
const int64_t nchannels_y = 1;
|
||||
const int64_t channel_stride_x = 0;
|
||||
const int64_t channel_stride_y = 0;
|
||||
const int64_t channel_stride_dst = 0;
|
||||
|
||||
mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
||||
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
||||
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_ddq_i);
|
||||
GGML_UNUSED(src1_ncols);
|
||||
GGML_UNUSED(src1_padded_row_size);
|
||||
}
|
|
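The block-size selection in launch_mul_mat_vec_cuda simply minimizes the number of loop iterations each thread performs: every thread consumes one half2 (two columns) per iteration, and ties prefer the smaller block. A minimal host-side C sketch of the same arithmetic, assuming WARP_SIZE == 32 and an illustrative row length (neither value comes from this diff):

#include <stdint.h>
#include <stdio.h>

#define WARP_SIZE 32

int main(void) {
    const int64_t ncols = 4096; // example row length
    int64_t block_size_best = WARP_SIZE;
    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
        // each thread handles 2 columns (one half2) per iteration
        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
        if (niter < niter_best) { // strict <, so ties keep the smaller block
            niter_best      = niter;
            block_size_best = block_size;
        }
    }
    printf("block_size=%lld iters=%lld\n", (long long) block_size_best, (long long) niter_best);
    return 0;
}

For ncols = 4096 this picks a block size of 256 with 8 iterations per thread.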
@ -1,20 +1,12 @@
#include "common.cuh"

// dmmv = dequantize_mul_mat_vec
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
#define MMV_MAX_ROWS 512

// TODO: remove this?
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);

#ifndef GGML_CUDA_MMV_Y
#define GGML_CUDA_MMV_Y 1
#endif

void ggml_cuda_op_dequantize_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);
@ -295,6 +295,9 @@ struct ggml_cgraph {
    enum ggml_cgraph_eval_order order;
};

// returns a slice of cgraph with nodes [i0, i1)
// the slice does not have leafs or gradients
// if you need the gradients, get them from the original graph
struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);

// Memory allocation
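A hedged usage sketch of the new API (the helper name and split point are invented, and the actual compute/scheduling calls are elided): split a graph's node range in two and evaluate the halves separately. The views alias the original nodes, so no copies are made.

#include "ggml.h"

static void eval_in_two_slices(struct ggml_cgraph * gf) {
    struct ggml_cgraph lo = ggml_graph_view(gf, 0, gf->n_nodes/2);
    struct ggml_cgraph hi = ggml_graph_view(gf, gf->n_nodes/2, gf->n_nodes);
    // ... run `lo`, then `hi`, with the backend scheduler of choice;
    // gradients are not carried over, fetch them from `gf` if needed ...
    (void) lo; (void) hi;
}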
249
ggml/src/ggml-metal/ggml-metal-impl.h
Normal file
@ -0,0 +1,249 @@
#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful about int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
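A hedged illustration of the overflow caveat above (the helper is invented for this note, it is not part of the diff): with int32_t counters, a flattened element index must be promoted to 64 bits before the multiplications, not after.

#include <stdint.h>

// flat element index for a 4D tensor with int32_t extents ne0..ne2;
// casting early keeps every intermediate product in 64-bit arithmetic
static inline int64_t flat_index(int32_t i0, int32_t i1, int32_t i2, int32_t i3,
                                 int32_t ne0, int32_t ne1, int32_t ne2) {
    return i0 + (int64_t) ne0*(i1 + (int64_t) ne1*(i2 + (int64_t) ne2*i3));
}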
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;

typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;

typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;

typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;

typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;

typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;

typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;

typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;

#endif // GGML_METAL_IMPL
@ -2,6 +2,7 @@

#import "ggml-impl.h"
#import "ggml-backend-impl.h"
#import "ggml-metal-impl.h"

#import <Foundation/Foundation.h>

@ -125,6 +126,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
    GGML_METAL_KERNEL_TYPE_SILU,
    GGML_METAL_KERNEL_TYPE_SILU_4,
    GGML_METAL_KERNEL_TYPE_ELU,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,

@ -648,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);

@ -967,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_UNARY_OP_GELU:
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_ELU:
            return ggml_is_contiguous(op->src[0]);
        default:
            return false;

@ -1193,35 +1197,39 @@ static void ggml_metal_encode_node(

            const int32_t dim = ((const int32_t *) dst->op_params)[0];

            ggml_metal_kargs_concat args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne10 =*/ ne10,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.ne13 =*/ ne13,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.nb13 =*/ nb13,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ nb1,
                /*.nb2  =*/ nb2,
                /*.nb3  =*/ nb3,
                /*.dim  =*/ dim,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
            [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
            [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

            const int nth = MIN(1024, ne0);
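The pattern repeated throughout this file: the long run of per-parameter setBytes calls is collapsed into a single kargs struct bound once at buffer index 0, with the data buffers shifted up by one index. A hedged host-side C sketch of how such a struct gets populated (the helper is hypothetical and only a few fields are shown; the real encoder fills all of them, in declaration order):

#include "ggml.h"
#include "ggml-metal-impl.h"

static ggml_metal_kargs_concat make_concat_args(const struct ggml_tensor * src0, int32_t dim) {
    ggml_metal_kargs_concat args = {
        /*.ne00 =*/ (int32_t)  src0->ne[0],
        /*.ne01 =*/ (int32_t)  src0->ne[1],
        /*.ne02 =*/ (int32_t)  src0->ne[2],
        /*.ne03 =*/ (int32_t)  src0->ne[3],
        /*.nb00 =*/ (uint64_t) src0->nb[0],
        // ... remaining counters and strides filled the same way ...
    };
    args.dim = dim;
    return args;
}

Because the struct is copied by value into the command buffer, the C layout must match the Metal shader's struct exactly, which is why both sides now share ggml-metal-impl.h.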
@ -1239,8 +1247,6 @@ static void ggml_metal_encode_node(

            bool bcast_row = false;

            int64_t nb = ne00; // used by the "row" kernels

            id<MTLComputePipelineState> pipeline = nil;

            if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {

@ -1249,7 +1255,6 @@ static void ggml_metal_encode_node(
                // src1 is a row
                GGML_ASSERT(ne11 == 1);

                nb = ne00 / 4;
                switch (dst->op) {
                    case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
                    case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;

@ -1269,36 +1274,39 @@ static void ggml_metal_encode_node(
                }
            }

            ggml_metal_kargs_bin args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne10 =*/ ne10,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.ne13 =*/ ne13,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.nb13 =*/ nb13,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ nb1,
                /*.nb2  =*/ nb2,
                /*.nb3  =*/ nb3,
                /*.offs =*/ offs,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
            [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
            [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

            if (bcast_row) {
                const int64_t n = ggml_nelements(dst)/4;

@ -1322,25 +1330,29 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("fatal error");
            }

            ggml_metal_kargs_repeat args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ nb1,
                /*.nb2  =*/ nb2,
                /*.nb3  =*/ nb3,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
            [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

@ -1369,25 +1381,29 @@ static void ggml_metal_encode_node(

            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;

            ggml_metal_kargs_cpy args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ nb1,
                /*.nb2  =*/ nb2,
                /*.nb3  =*/ nb3,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);

@ -1396,35 +1412,39 @@ static void ggml_metal_encode_node(

            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;

            ggml_metal_kargs_bin args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ pnb1,
                /*.nb02 =*/ pnb2,
                /*.nb03 =*/ pnb3,
                /*.ne10 =*/ ne10,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.ne13 =*/ ne13,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.nb13 =*/ nb13,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ pnb1,
                /*.nb2  =*/ pnb2,
                /*.nb3  =*/ pnb3,
                /*.offs =*/ offs,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);

@ -1572,6 +1592,18 @@ static void ggml_metal_encode_node(

                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
            case GGML_UNARY_OP_ELU:
            {
                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                const int64_t n = ggml_nelements(dst);

                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
            default:
            {
                GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
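For reference, the new ELU path computes the standard Exponential Linear Unit, which ggml defines with alpha fixed at 1. A scalar C reference of what the Metal kernel is expected to do element-wise (the helper name is invented):

#include <math.h>

// ELU with alpha = 1: identity for positive inputs, exp(x) - 1 otherwise;
// expm1f keeps precision for small negative x
static float elu_ref(float x) {
    return x > 0.0f ? x : expm1f(x);
}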
@ -1640,6 +1672,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -1715,6 +1748,8 @@ static void ggml_metal_encode_node(
            const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

            // TODO: add ggml_metal_kargs struct
            // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            if (id_src1) {

@ -1731,6 +1766,7 @@ static void ggml_metal_encode_node(
            [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
            [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
            [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];

            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

            [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
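The m0 and m1 constants above are the two geometric bases from which per-head ALiBi slopes are derived. A hedged C reference of the slope formula as used across ggml's soft_max kernels (h is the head index; the helper name is invented):

#include <math.h>

// heads below n_head_log2 take successive powers of m0,
// the remaining heads take odd powers of m1
static float alibi_slope(int h, int n_head_log2, float m0, float m1) {
    return h < n_head_log2 ? powf(m0, (float) (h + 1))
                           : powf(m1, (float) (2*(h - n_head_log2) + 1));
}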
@ -1747,6 +1783,7 @@ static void ggml_metal_encode_node(
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
            }

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -1771,6 +1808,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];

@ -1841,6 +1879,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];

@ -1959,24 +1998,29 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("MUL MAT-MAT not implemented");
            }

            ggml_metal_kargs_mul_mm args = {
                /*.ne00 =*/ ne00,
                /*.ne02 =*/ ne02,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne12 =*/ ne12,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.nb13 =*/ nb13,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.r2   =*/ r2,
                /*.r3   =*/ r3,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
            [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
            [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
            [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
        } else {

@ -2154,28 +2198,32 @@ static void ggml_metal_encode_node(
                }
            };

            ggml_metal_kargs_mul_mv args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne10 =*/ ne10,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.nb13 =*/ nb13,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.r2   =*/ r2,
                /*.r3   =*/ r3,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
            [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
            [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||

@ -2288,27 +2336,30 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("MUL_MAT_ID not implemented");
            }

            ggml_metal_kargs_mul_mm_id args = {
                /*.nei0 =*/ ne20,
                /*.nei1 =*/ ne21,
                /*.nbi1 =*/ nb21,
                /*.ne00 =*/ ne00,
                /*.ne02 =*/ ne02,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.ne13 =*/ ne13,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];

@ -2467,30 +2518,34 @@ static void ggml_metal_encode_node(
                GGML_ASSERT(ne00 >= nth0*nth1);
            }

            ggml_metal_kargs_mul_mv_id args = {
                /*.nei0 =*/ ne20,
                /*.nei1 =*/ ne21,
                /*.nbi1 =*/ nb21,
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.ne10 =*/ ne10,
                /*.ne11 =*/ ne11,
                /*.ne12 =*/ ne12,
                /*.ne13 =*/ ne13,
                /*.nb10 =*/ nb10,
                /*.nb11 =*/ nb11,
                /*.nb12 =*/ nb12,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.nb1  =*/ nb1,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

            const int64_t _ne1 = 1;
            const int tgz = dst_rows;

@ -2563,6 +2618,7 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("not implemented");
            }

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];

@ -2586,20 +2642,28 @@ static void ggml_metal_encode_node(
            float eps;
            memcpy(&eps, dst->op_params, sizeof(float));

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;

            int nth = 32; // SIMD width

            while (nth < ne00/4 && nth < 1024) {
            while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
                nth *= 2;
            }

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
            nth = MIN(nth, ne00/4);

            ggml_metal_kargs_rms_norm args = {
                /*.ne00   =*/ ne00,
                /*.ne00_4 =*/ ne00/4,
                /*.nb01   =*/ nb01,
                /*.eps    =*/ eps,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
            [encoder setBytes:&eps length:sizeof( float) atIndex:4];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

            const int64_t nrows = ggml_nrows(src0);
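The RMS_NORM change replaces the hard-coded 1024 thread cap with the pipeline's actual maxTotalThreadsPerThreadgroup, then clamps to ne00/4 since the kernel reads in float4 units. A host-side C sketch of the same selection (max_threads stands in for the pipeline query):

static int pick_nth(int ne00, int max_threads) {
    int nth = 32; // SIMD width
    while (nth < ne00/4 && nth < max_threads) {
        nth *= 2;
    }
    return nth < ne00/4 ? nth : ne00/4; // MIN(nth, ne00/4)
}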
@ -2624,6 +2688,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2641,22 +2706,35 @@ static void ggml_metal_encode_node(
        } break;
        case GGML_OP_NORM:
        {
            GGML_ASSERT(ne00 % 4 == 0);
            GGML_ASSERT(ggml_is_contiguous_1(src0));

            float eps;
            memcpy(&eps, dst->op_params, sizeof(float));

            const int nth = MIN(256, ne00);

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;

            int nth = 32; // SIMD width

            while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
                nth *= 2;
            }

            nth = MIN(nth, ne00/4);

            ggml_metal_kargs_norm args = {
                /*.ne00   =*/ ne00,
                /*.ne00_4 =*/ ne00/4,
                /*.nb01   =*/ nb01,
                /*.eps    =*/ eps,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
            [encoder setBytes:&eps length:sizeof( float) atIndex:4];
            [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

            const int64_t nrows = ggml_nrows(src0);

@ -2706,40 +2784,44 @@ static void ggml_metal_encode_node(
                };
            }

            ggml_metal_kargs_rope args = {
                /*.ne00       =*/ ne00,
                /*.ne01       =*/ ne01,
                /*.ne02       =*/ ne02,
                /*.ne03       =*/ ne03,
                /*.nb00       =*/ nb00,
                /*.nb01       =*/ nb01,
                /*.nb02       =*/ nb02,
                /*.nb03       =*/ nb03,
                /*.ne0        =*/ ne0,
                /*.ne1        =*/ ne1,
                /*.ne2        =*/ ne2,
                /*.ne3        =*/ ne3,
                /*.nb0        =*/ nb0,
                /*.nb1        =*/ nb1,
                /*.nb2        =*/ nb2,
                /*.nb3        =*/ nb3,
                /*.n_past     =*/ n_past,
                /*.n_dims     =*/ n_dims,
                /*.n_ctx_orig =*/ n_ctx_orig,
                /*.freq_base  =*/ freq_base,
                /*.freq_scale =*/ freq_scale,
                /*.ext_factor =*/ ext_factor,
                /*.attn_factor =*/ attn_factor,
                /*.beta_fast  =*/ beta_fast,
                /*.beta_slow  =*/ beta_slow,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            if (id_src2 != nil) {
                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
            } else {
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
            }
            [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
            [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
            [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
            [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
            [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
            [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
            [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
            [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
            [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
            [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:4];

            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
        } break;

@ -2796,6 +2878,7 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("fatal error");
            };

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2836,6 +2919,7 @@ static void ggml_metal_encode_node(

            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2870,6 +2954,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2906,6 +2991,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];

@ -2927,6 +3013,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2965,6 +3052,7 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("fatal error");
            };

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -2983,6 +3071,7 @@ static void ggml_metal_encode_node(

            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

@ -3224,37 +3313,41 @@ static void ggml_metal_encode_node(
                }
            }

            ggml_metal_kargs_flash_attn_ext args = {
                /*.ne01          =*/ ne01,
                /*.ne02          =*/ ne02,
                /*.ne03          =*/ ne03,
                /*.nb01          =*/ nb01,
                /*.nb02          =*/ nb02,
                /*.nb03          =*/ nb03,
                /*.ne11          =*/ ne11,
                /*.ne_12_2       =*/ ne12,
                /*.ne_12_3       =*/ ne13,
                /*.nb_12_1       =*/ nb11,
                /*.nb_12_2       =*/ nb12,
                /*.nb_12_3       =*/ nb13,
                /*.nb31          =*/ nb31,
                /*.ne1           =*/ ne1,
                /*.ne2           =*/ ne2,
                /*.scale         =*/ scale,
                /*.max_bias      =*/ max_bias,
                /*.m0            =*/ m0,
                /*.m1            =*/ m1,
                /*.n_head_log2   =*/ n_head_log2,
                /*.logit_softcap =*/ logit_softcap,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
            if (id_src3) {
                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
            } else {
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
            }
            [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
            [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
            [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
            [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
            [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
            [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
            [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
            [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
            [encoder setBytes:&scale length:sizeof( float) atIndex:20];
            [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
            [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
            [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
            [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
            [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:5];

            if (!use_vec_kernel) {
                // half8x8 kernel

@ -3389,25 +3482,29 @@ static void ggml_metal_encode_node(
                default: GGML_ABORT("not implemented");
            }

            ggml_metal_kargs_cpy args = {
                /*.ne00 =*/ ne00,
                /*.ne01 =*/ ne01,
                /*.ne02 =*/ ne02,
                /*.ne03 =*/ ne03,
                /*.nb00 =*/ nb00,
                /*.nb01 =*/ nb01,
                /*.nb02 =*/ nb02,
                /*.nb03 =*/ nb03,
                /*.ne0  =*/ ne0,
                /*.ne1  =*/ ne1,
                /*.ne2  =*/ ne2,
                /*.ne3  =*/ ne3,
                /*.nb0  =*/ nb0,
                /*.nb1  =*/ nb1,
                /*.nb2  =*/ nb2,
                /*.nb3  =*/ nb3,
            };

            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
            [encoder setBytes:&args length:sizeof(args) atIndex:0];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
        } break;

@ -3452,6 +3549,7 @@ static void ggml_metal_encode_node(
            const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
            const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

            // TODO: add ggml_metal_kargs struct
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
(file diff suppressed because it is too large)
@@ -14,51 +14,51 @@
 #include <vector>

 struct ggml_opt_dataset {
-    struct ggml_context * ctx;
-    ggml_backend_buffer_t buf;
-    struct ggml_tensor  * data;
-    struct ggml_tensor  * labels;
+    struct ggml_context * ctx    = nullptr;
+    ggml_backend_buffer_t buf    = nullptr;
+    struct ggml_tensor  * data   = nullptr;
+    struct ggml_tensor  * labels = nullptr;

-    int64_t ndata;
-    int64_t ndata_shard;
-    size_t  nbs_data;
-    size_t  nbs_labels;
+    int64_t ndata       = -1;
+    int64_t ndata_shard = -1;
+    size_t  nbs_data    = -1;
+    size_t  nbs_labels  = -1;

     std::vector<int64_t> permutation;
 };

 struct ggml_opt_context {
-    ggml_backend_sched_t  backend_sched;
-    ggml_cgraph         * allocated_graph;
-    ggml_cgraph         * allocated_graph_copy;
-    struct ggml_context * ctx_static;
-    struct ggml_context * ctx_static_cpu;
-    struct ggml_context * ctx_compute;
-    struct ggml_context * ctx_copy;
-    ggml_backend_buffer_t buf_static;
-    ggml_backend_buffer_t buf_static_cpu;
+    ggml_backend_sched_t  backend_sched        = nullptr;
+    ggml_cgraph         * allocated_graph      = nullptr;
+    ggml_cgraph         * allocated_graph_copy = nullptr;
+    struct ggml_context * ctx_static           = nullptr;
+    struct ggml_context * ctx_static_cpu       = nullptr;
+    struct ggml_context * ctx_compute          = nullptr;
+    struct ggml_context * ctx_copy             = nullptr;
+    ggml_backend_buffer_t buf_static           = nullptr;
+    ggml_backend_buffer_t buf_static_cpu       = nullptr;
     std::mt19937 rng;

-    struct ggml_tensor * inputs;
-    struct ggml_tensor * outputs;
-    struct ggml_tensor * labels;
+    struct ggml_tensor * inputs  = nullptr;
+    struct ggml_tensor * outputs = nullptr;
+    struct ggml_tensor * labels  = nullptr;

-    struct ggml_tensor * loss;
-    struct ggml_tensor * pred;
-    struct ggml_tensor * ncorrect;
+    struct ggml_tensor * loss     = nullptr;
+    struct ggml_tensor * pred     = nullptr;
+    struct ggml_tensor * ncorrect = nullptr;

-    struct ggml_cgraph * gf;
-    struct ggml_cgraph * gb_grad;
-    struct ggml_cgraph * gb_opt;
+    struct ggml_cgraph * gf      = nullptr;
+    struct ggml_cgraph * gb_grad = nullptr;
+    struct ggml_cgraph * gb_opt  = nullptr;

-    int64_t iter;
-    int32_t opt_period;
-    int32_t opt_i;
-    bool    loss_per_datapoint;
+    int64_t iter               = 1;
+    int32_t opt_period         = 1;
+    int32_t opt_i              = 0;
+    bool    loss_per_datapoint = false;

-    ggml_opt_get_optimizer_params get_opt_pars;
-    void * get_opt_pars_ud;
-    struct ggml_tensor * adamw_params;
+    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud                     = nullptr;
+    struct ggml_tensor * adamw_params          = nullptr;
 };

 struct ggml_opt_result {
@@ -67,8 +67,8 @@ struct ggml_opt_result {
     std::vector<int32_t> pred;
     int64_t              ncorrect = 0;

-    bool    loss_per_datapoint = false;
+    int64_t opt_period         = -1;
+    bool    loss_per_datapoint = false;
 };

 // ====== Dataset ======
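
The change above moves all initialization onto the members themselves as in-class default member initializers, so a freshly constructed `ggml_opt_context` is in a well-defined state no matter which code path created it. A minimal sketch of the idiom (the struct here is hypothetical, not from the source):

    // Sketch: defaults live on the members, so `new example_ctx` can never
    // leave a field uninitialized, and the init function only assigns what
    // actually comes from its parameters.
    struct example_ctx {
        void *  backend = nullptr;  // previously assigned inside the init function
        int64_t iter    = 1;
        int32_t period  = 1;
    };

This also lets the init function below drop its redundant `result->... = nullptr;` assignments, which is exactly what the next hunks do.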
@@ -237,25 +237,33 @@ static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_
     return new_tensor;
 }

-static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
+static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
     std::map<ggml_tensor *, ggml_tensor *> tensor_map;

-    ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+    ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);

-    for (int i = 0; i < graph->n_leafs; i++) {
-        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
+    for (int i = 0; i < src->n_leafs; i++) {
+        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
     }
-    for (int i = 0; i < graph->n_nodes; i++) {
-        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
+    GGML_ASSERT(dst->n_leafs == src->n_leafs);
+    for (int i = 0; i < src->n_nodes; i++) {
+        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
     }
-    for (int i = 0; i < graph->n_nodes; ++i) {
-        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
-        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
-
-        graph->grads[igrad_dst]     = new_graph->grads[igrad_src];
-        graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
+    GGML_ASSERT(dst->n_nodes == src->n_nodes);
+    for (int i = 0; i < src->n_nodes; ++i) {
+        const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+        const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+        GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+        GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
+        dst->grads[igrad_dst]     = src->grads[igrad_src];
+        dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
     }

-    return new_graph;
+    return dst;
 }

 static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
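
Two fixes land in dup_graph at once: gradients are now copied src -> dst instead of being written back into the source graph, and every hash-set lookup is validated before it is dereferenced. A hedged sketch of the checked-lookup shape, with std::unordered_map standing in for ggml's hash set (names here are illustrative only):

    #include <cassert>
    #include <unordered_map>

    // Sketch: resolve the gradient slot in *both* graphs, verify both
    // lookups succeeded, then copy src -> dst. The old code skipped the
    // checks and wrote into the source graph's arrays.
    template <typename Node, typename Grad>
    void copy_grad(const std::unordered_map<Node*, Grad> & src,
                   std::unordered_map<Node*, Grad>       & dst,
                   Node * src_node, Node * dst_node) {
        auto it_src = src.find(src_node);
        auto it_dst = dst.find(dst_node);
        assert(it_src != src.end() && it_dst != dst.end());
        it_dst->second = it_src->second;
    }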
@@ -285,15 +293,10 @@ static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph
 ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     ggml_opt_context_t result = new struct ggml_opt_context;
     result->backend_sched    = params.backend_sched;
-    result->allocated_graph      = nullptr;
-    result->allocated_graph_copy = nullptr;
     result->ctx_compute      = params.ctx_compute;
-    result->ctx_copy         = nullptr;
     result->inputs           = params.inputs;
     result->outputs          = params.outputs;
-    result->iter             = 1;
     result->opt_period       = params.opt_period;
-    result->opt_i            = 0;
     result->get_opt_pars     = params.get_opt_pars;
     result->get_opt_pars_ud  = params.get_opt_pars_ud;

@@ -348,7 +351,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {

     switch (params.loss_type) {
         case GGML_OPT_LOSS_TYPE_MEAN: {
-            result->labels = nullptr;
             result->loss = ggml_sum(result->ctx_static, result->outputs);
             ggml_set_name(result->loss, "loss_sum");
             const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
@@ -358,7 +360,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
             break;
         }
         case GGML_OPT_LOSS_TYPE_SUM: {
-            result->labels = nullptr;
             result->loss = ggml_sum(result->ctx_static, result->outputs);
             ggml_set_name(result->loss, "loss_sum");
             result->loss_per_datapoint = false;
@@ -413,14 +414,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     }

     if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
-        result->gb_grad = nullptr;
-        result->gb_opt  = nullptr;
-
         result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        result->buf_static_cpu = nullptr;

         ggml_opt_alloc_graph(result, result->gf);

         return result;
     }

@@ -429,14 +423,8 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);

     if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
-        result->gb_opt = nullptr;
-
         result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        result->buf_static_cpu = nullptr;

         ggml_opt_alloc_graph(result, result->gb_grad);
         ggml_graph_reset(result->gb_grad);

         return result;
     }

@@ -466,7 +454,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {

     result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());

     ggml_opt_alloc_graph(result, result->gb_opt);
     ggml_graph_reset(result->gb_opt);

     return result;
@@ -4350,10 +4350,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
     if (op->op == GGML_OP_MUL_MAT) {
         a = op->src[0];
         b = op->src[1];
-        if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
-            // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
-            return false;
-        }
     } else {
         a = op->src[2];
         b = op->src[1];
(file diff suppressed because it is too large)
(file diff suppressed because it is too large)
@@ -158,6 +158,7 @@ struct vk_device_struct {
     std::string name;
     uint64_t max_memory_allocation_size;
     bool fp16;
+    bool pipeline_robustness;
    vk::Device device;
     uint32_t vendor_id;
     vk_queue compute_queue;
@@ -218,6 +219,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
+    vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_argsort_f32;
@@ -388,6 +390,7 @@ struct vk_op_soft_max_push_constants {
     float m0;
     float m1;
     uint32_t n_head_log2;
+    uint32_t nrows_x;
 };

 struct vk_op_argsort_push_constants {
@@ -652,7 +655,7 @@ static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;

-static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
@@ -722,6 +725,15 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         vk::PipelineCreateFlags(),
         pipeline_shader_create_info,
         pipeline->layout);

+    vk::PipelineRobustnessCreateInfoEXT rci;
+
+    if (device->pipeline_robustness && disable_robustness) {
+        rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
+        rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
+        compute_pipeline_create_info.setPNext(&rci);
+    }
+
     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;

     {
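
The robustness opt-out above works by chaining an extension struct into the compute-pipeline create info via pNext: shaders that do their own bounds checks (the `disable_robustness = true` pipelines below) skip the driver's robust-buffer-access clamping and run faster. A condensed C++ sketch of that chaining with Vulkan-Hpp, assuming the extension was already detected on the device:

    #include <vulkan/vulkan.hpp>

    // Sketch: opt a single pipeline out of robust buffer access when
    // VK_EXT_pipeline_robustness is available (assumed detected earlier).
    void chain_robustness(vk::ComputePipelineCreateInfo       & info,
                          vk::PipelineRobustnessCreateInfoEXT & rci,
                          bool have_ext, bool disable_robustness) {
        if (have_ext && disable_robustness) {
            rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
            rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
            info.setPNext(&rci);  // rci must outlive the create call
        }
    }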
@@ -1259,7 +1271,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();

     std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align) {
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
         {
             // wait until fewer than N compiles are in progress
             uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1269,7 +1281,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             }
             compile_count++;
         }
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
     };

     if (device->fp16) {
@@ -1368,45 +1380,45 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);

     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);

     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1497,8 +1509,10 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);

-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);

     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -1587,12 +1601,15 @@ static vk_device ggml_vk_get_device(size_t idx) {

         bool fp16_storage = false;
         bool fp16_compute = false;
+        bool pipeline_robustness = false;

         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
                 fp16_storage = true;
             } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
                 fp16_compute = true;
+            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
+                pipeline_robustness = true;
             }
         }
@@ -1638,10 +1655,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
         vk11_features.pNext = &vk12_features;

+        VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
+        pl_robustness_features.pNext = nullptr;
+        pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
+        pl_robustness_features.pipelineRobustness = VK_FALSE;
+
+        if (pipeline_robustness) {
+            vk12_features.pNext = &pl_robustness_features;
+            device_extensions.push_back("VK_EXT_pipeline_robustness");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

         device->fp16 = device->fp16 && vk12_features.shaderFloat16;

+        device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+
         if (!vk11_features.storageBuffer16BitAccess) {
             std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
             throw std::runtime_error("Unsupported device");
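
The feature query uses the standard Vulkan pNext idiom: the extension's feature struct is linked into the chain behind VkPhysicalDeviceFeatures2, and the driver fills in `pipelineRobustness` during the query. A self-contained C++ sketch of just that step (the surrounding chain of vk11/vk12 feature structs is omitted; the helper name is illustrative):

    #include <vulkan/vulkan.h>

    // Sketch: ask the driver whether VK_EXT_pipeline_robustness is usable,
    // assuming extension presence was already checked against the device's
    // extension list (as in the strcmp loop above).
    bool query_pipeline_robustness(VkPhysicalDevice phys) {
        VkPhysicalDevicePipelineRobustnessFeaturesEXT pl = {};
        pl.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;

        VkPhysicalDeviceFeatures2 features2 = {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        features2.pNext = &pl;  // driver writes into pl during the query

        vkGetPhysicalDeviceFeatures2(phys, &features2);
        return pl.pipelineRobustness == VK_TRUE;
    }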
@@ -1764,11 +1793,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     fp16 = fp16 && vk12_features.shaderFloat16;

     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %d = %s (%s) | uma: %d | fp16: %d | warp size: %d\n",
-              idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n",
+              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size);

     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
-        std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
+        GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
     }
 }
@@ -1937,8 +1966,7 @@ void ggml_vk_instance_init() {
             vk_instance.device_indices.push_back(0);
         }
     }
-    GGML_LOG_DEBUG("ggml_vulkan: Found %d Vulkan devices:\n", vk_instance.device_indices.size());

+    GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());

     for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
         ggml_vk_print_gpu_info(i);
@@ -3187,7 +3215,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&

     if (ne01 > max_groups_x) {
         groups_z = 64;
-        groups_x /= groups_z;
+        groups_x = CEIL_DIV(groups_x, groups_z);
     }

     // compute
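
The fix matters because plain integer division truncates: when groups_x is not a multiple of groups_z, `groups_x /= groups_z` drops up to groups_z - 1 groups' worth of rows from the dispatch, while ceiling division always covers them all. A small runnable illustration of the macro's semantics (CEIL_DIV is the same round-up macro the backend uses):

    #include <cassert>

    #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

    int main() {
        // Splitting 130 groups across z = 64: plain division yields 2
        // (covers only 128 groups), ceiling division yields 3 (covers 130).
        assert(130 / 64 == 2);
        assert(CEIL_DIV(130, 64) == 3);
        return 0;
    }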
@@ -3764,7 +3792,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte

     if (ne01 > max_groups_x) {
         groups_z = 64;
-        groups_x /= groups_z;
+        groups_x = CEIL_DIV(groups_x, groups_z);
     }

     // compute
@@ -3933,10 +3961,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_soft_max_f32;
+            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
        }
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_soft_max_f32_f16;
+            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
     case GGML_OP_ROPE:
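
The dispatch-side selection is just a threshold on the row length: rows longer than 1024 elements use the variant compiled with a 512-wide workgroup specialization constant, everything else keeps the subgroup-sized default built above. A hedged sketch of that selection shape (the handle type is a template placeholder, not the backend's real type):

    #include <cstdint>

    // Sketch of the rule from the hunk: long softmax rows benefit from a
    // wider workgroup, short rows from the subgroup-sized one.
    template <typename Pipeline>
    Pipeline select_soft_max(Pipeline fallback, Pipeline wg512, int64_t ne0) {
        return ne0 > 1024 ? wg512 : fallback;
    }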
@@ -4582,6 +4610,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         scale, max_bias,
         m0, m1,
         n_head_log2,
+        nrows_x,
     }, dryrun);
 }
@@ -2,6 +2,15 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
 #endif

 #include "types.comp"

+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 #if defined(DATA_A_F32)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
+}
 #endif

 #if defined(DATA_A_Q4_1)
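
The new dequantize4 variants read one packed 16-bit word through the aliased data_a_packed16 binding instead of two single-byte loads, then peel four 4-bit values out of it. The same bit manipulation as the Q4_0 shader above, written in C++ for reference (function name and types are illustrative):

    #include <array>
    #include <cstdint>

    // Same unpacking as the Q4_0 dequantize4 above: one 16-bit load yields
    // four nibbles; each is shifted from [0, 15] to [-8, 7] and scaled by
    // the block scale d.
    std::array<float, 4> dequantize4_q4_0(uint16_t packed, float d) {
        return {
            (float( packed        & 0xF) - 8.0f) * d,
            (float((packed >>  4) & 0xF) - 8.0f) * d,
            (float((packed >>  8) & 0xF) - 8.0f) * d,
            (float((packed >> 12) & 0xF) - 8.0f) * d,
        };
    }

The Q4_1/Q5_x/Q8_0/IQ4_NL variants that follow apply the same wide-load idea with their respective offsets, high bits, or lookup tables.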
@@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2(vui & 0xF, vui >> 4) * d + m;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const float m = float(data_a_packed16[a_offset + ib].m);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
+}
 #endif

 #if defined(DATA_A_Q5_0)
@@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
+    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
+}
 #endif

 #if defined(DATA_A_Q5_1)
@@ -50,6 +78,15 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const float m = float(data_a_packed16[a_offset + ib].m);
+    const uint uint_qh = data_a_packed16[a_offset + ib].qh;
+    const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
+}
 #endif

 #if defined(DATA_A_Q8_0)
@@ -57,6 +94,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const float d = float(data_a[a_offset + ib].d);
     return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
+    uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
+    return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
+}
 #endif

 #if defined(DATA_A_IQ4_NL)
@@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
     return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
 }
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const float d = float(data_a_packed16[a_offset + ib].d);
+    const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
+    return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
+}
 #endif
@@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
 void main() {
     const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;

+    init_iq4nl_shmem();
+
     const uint tid = gl_LocalInvocationID.x % 64;
     const uint il = tid/32;
     const uint ir = tid%32;
@@ -12,6 +12,10 @@ void main() {
     const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
     const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;

+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
     if (i00 >= p.ne00) {
         return;
     }
@@ -3,9 +3,7 @@
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
-#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-
-#extension GL_EXT_null_initializer : enable
+#extension GL_EXT_shader_explicit_arithmetic_types : require

 #include "mul_mat_vec_base.comp"
@@ -14,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 layout (constant_id = 1) const uint NUM_ROWS = 1;

+#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#define K_PER_ITER 8
+#else
+#define K_PER_ITER 2
+#endif
+
 uint a_offset, b_offset, d_offset, y_offset;

 shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];

 void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 {
-    const uint col = i*BLOCK_SIZE + 2*tid;
+    const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
     const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
     const uint iybs = col - col%QUANT_K; // y block start index

+#if K_PER_ITER == 8
+#if QUANT_R == 2
+    B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
+    B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
+    FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
+    FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
+    FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
+    FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
+    FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
+    FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
+    FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
+    FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
+#else
+    B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
+    B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
+    FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
+    FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
+    FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
+    FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
+    FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
+    FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
+    FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
+    FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
+#endif
+#else
     // Check if the second of the pair of elements is OOB, and don't fetch B or
     // accumulate it. We still fetch a pair of elements for A, which is fine for
     // quantized formats since they'll be within the same block. We should
@@ -36,9 +66,24 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
     if (!OOB) {
         b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
     }
+#endif
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index

+#if K_PER_ITER == 8
+        const vec4 v  = dequantize4(ib, iqs, a_offset);
+        const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+
+        // matrix multiplication
+        temp[n] = fma(FLOAT_TYPE(v.x),  b0, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.y),  b1, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.z),  b2, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v.w),  b3, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
+        temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
+#else
         const vec2 v = dequantize(ib, iqs, a_offset);

         // matrix multiplication
@@ -46,6 +91,7 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
         if (!OOB) {
             temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
         }
+#endif
     }
 }
@@ -57,24 +103,39 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

-    FLOAT_TYPE temp[NUM_ROWS] = {};
+    FLOAT_TYPE temp[NUM_ROWS];

-    const int unroll_count = 8;
+    for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }

-    const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
-    const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
+    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
+        num_iters++;
+    }
+    int unroll_count = 4;
+    uint unrolled_iters = num_iters & ~(unroll_count - 1);

     uint i = 0;
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid, i, false);
-            i += 2;
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
         }
     }
+    unroll_count = 2;
+    unrolled_iters = num_iters & ~(unroll_count - 1);
+    while (i < unrolled_iters) {
+        // Manually partially unroll the loop
+        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
+            i++;
+        }
+    }
     while (i < num_iters) {
-        iter(temp, first_row, num_rows, tid, i, true);
-        i += 2;
+        iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
+        i++;
     }

     // sum up partial sums and write back result
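
The restructured loop runs a 4-wide unrolled pass over the bulk of the iterations, a 2-wide pass over most of the remainder, and a scalar tail; `num_iters & ~(unroll_count - 1)` rounds the bound down to a multiple of the unroll factor, which is valid because the factors are powers of two. A CPU-side C++ sketch of the same staging (body() stands in for iter(...); names are illustrative):

    #include <cstdint>

    // Sketch of the staged partial unrolling in compute_outputs: process
    // iterations in groups of 4, then 2, then one at a time.
    template <typename F>
    void staged_unroll(uint32_t num_iters, F body) {
        uint32_t i = 0;
        for (uint32_t bound = num_iters & ~3u; i < bound; )  // 4-wide pass
            for (int k = 0; k < 4; ++k) body(i++);
        for (uint32_t bound = num_iters & ~1u; i < bound; )  // 2-wide pass
            for (int k = 0; k < 2; ++k) body(i++);
        while (i < num_iters) body(i++);                     // scalar tail
    }

The wider first pass keeps the hot loop branch-light, while the smaller follow-up passes avoid the old version's constraint that the trip count be a multiple of a single large unroll factor.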
@@ -100,10 +161,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 void main() {
     const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {
         compute_outputs(first_row, NUM_ROWS);
     } else {
         if (first_row >= p.stride_d) {
             return;
         }
         compute_outputs(first_row, p.stride_d - first_row);
     }
 }
@@ -12,6 +12,9 @@

 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

 shared FLOAT_TYPE tmp[32];

-// Declare aliased versions of A and B bindings that can use 16b/32b loads for
-// the quantized values, and vec4 loads for B.
-struct block_q4_K_u32
-{
-    f16vec2 d;
-    uint32_t scales[3*QUANT_K/64/4];
-    uint32_t qs[QUANT_K/2/4];
-};
-
-struct block_q4_K_u16
-{
-    f16vec2 d;
-    uint16_t scales[3*QUANT_K/64/2];
-    uint16_t qs[QUANT_K/2/2];
-};
-
-layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
-layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-
 // This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
@ -64,9 +48,9 @@ void main() {
|
|||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
|
||||
uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
|
||||
uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
|
||||
uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
|
||||
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
|
||||
uvec4 scale0 = uvec4(unpack8(scale0_u32));
|
||||
uvec4 scale4 = uvec4(unpack8(scale4_u32));
|
||||
uvec4 scale8 = uvec4(unpack8(scale8_u32));
|
||||
|
@@ -80,8 +64,8 @@ void main() {
         const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
         const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
 
-        uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
-        uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
+        uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
+        uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
 
         uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
         uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;

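The 0x0F0F0F0F masking above pulls eight 4-bit quants out of one 32-bit load instead of eight byte loads. A minimal standalone C++ sketch of the same nibble split (unpack_nibbles is an illustrative name, not part of the shader code):

    #include <cstdint>
    #include <cstdio>

    // Split one 32-bit word holding 8 packed 4-bit quants into low/high nibbles,
    // mirroring the qs0_u32_lo4 / qs0_u32_hi4 computation in the shader.
    static void unpack_nibbles(uint32_t qs, uint8_t lo[4], uint8_t hi[4]) {
        uint32_t lo4 = qs & 0x0F0F0F0F;          // low nibble of each byte
        uint32_t hi4 = (qs >> 4) & 0x0F0F0F0F;   // high nibble of each byte
        for (int i = 0; i < 4; ++i) {
            lo[i] = (lo4 >> (8 * i)) & 0xFF;     // like unpack8() per component
            hi[i] = (hi4 >> (8 * i)) & 0xFF;
        }
    }

    int main() {
        uint8_t lo[4], hi[4];
        unpack_nibbles(0xA1B2C3D4u, lo, hi);
        for (int i = 0; i < 4; ++i) {
            printf("byte %d: lo=%u hi=%u\n", i, lo[i], hi[i]);
        }
        return 0;
    }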
@@ -1,5 +1,7 @@
 #version 450
 
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+
 #include "mul_mat_vec_base.comp"
 
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
@@ -31,70 +37,106 @@ void main() {
     const uint8_t hm1 = uint8_t(1 << (2*v_im));
     const uint8_t hm2 = uint8_t(hm1 << 4);
 
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
+        f16vec2 d = data_a[ib0 + i].d;
+        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
-        const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
-        const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
-        const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
-        const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
-        const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
-        const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
-        const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
-        const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
+        uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
+        uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
+        uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
+        uvec4 scale0 = uvec4(unpack8(scale0_u32));
+        uvec4 scale4 = uvec4(unpack8(scale4_u32));
+        uvec4 scale8 = uvec4(unpack8(scale8_u32));
 
-        const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
-        const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
-        const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ] >> 4);
-        const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] >> 4);
-        const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
-        const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
-        const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
-        const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
-        const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
-        const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-        const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
-        const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
+        const uint32_t sc0 = (  scale0.x & 0x3f);
+        const uint32_t sc1 = (  scale0.y & 0x3f);
+        const uint32_t sc2 = (  scale4.x & 0x3f);
+        const uint32_t sc3 = (  scale4.y & 0x3f);
+        const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
+        const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
+        const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
+        const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
 
+        uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
+        uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
+
+        uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
+        uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
+        uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
+        uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
+
+        uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
+        uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
+        uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
+        uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
+
+        const uint32_t q4_0  = qs0_16_lo4.x;
+        const uint32_t q4_1  = qs0_16_lo4.y;
+        const uint32_t q4_2  = qs0_16_lo4.z;
+        const uint32_t q4_3  = qs0_16_lo4.w;
+        const uint32_t q4_4  = qs0_16_hi4.x;
+        const uint32_t q4_5  = qs0_16_hi4.y;
+        const uint32_t q4_6  = qs0_16_hi4.z;
+        const uint32_t q4_7  = qs0_16_hi4.w;
+        const uint32_t q4_8  = qs64_80_lo4.x;
+        const uint32_t q4_9  = qs64_80_lo4.y;
+        const uint32_t q4_10 = qs64_80_lo4.z;
+        const uint32_t q4_11 = qs64_80_lo4.w;
+        const uint32_t q4_12 = qs64_80_hi4.x;
+        const uint32_t q4_13 = qs64_80_hi4.y;
+        const uint32_t q4_14 = qs64_80_hi4.z;
+        const uint32_t q4_15 = qs64_80_hi4.w;
+
+        B_TYPE_VEC2 by10  = data_b_v2[(b_offset + y1_idx) / 2];
+        B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
+        B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
+        B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
+        B_TYPE_VEC2 by20  = data_b_v2[(b_offset + y2_idx) / 2];
+        B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
+        B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
+        B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
+
+        uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
+        uint32_t qh1 = qh0 >> 8;
+        uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
+        uint32_t qh17 = qh16 >> 8;
 
         const FLOAT_TYPE sx =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), (q4_0 + (((data_a[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), (q4_1 + (((data_a[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
-                FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
+            fma(FLOAT_TYPE(by10.x),  (q4_0 + (((qh0  & hm1) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by10.y),  (q4_1 + (((qh1  & hm1) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
+                FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
         const FLOAT_TYPE sy =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
-                FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
+            fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0  & (hm1 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1  & (hm1 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
+                FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
         const FLOAT_TYPE sz =
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), (q4_8  + (((data_a[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), (q4_9  + (((data_a[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
-                FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
+            fma(FLOAT_TYPE(by20.x),  (q4_8  + (((qh0  & hm2) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by20.y),  (q4_9  + (((qh1  & hm2) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
+                FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
         const FLOAT_TYPE sw =
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0)),
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
-                FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
+            fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0  & (hm2 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1  & (hm2 << 1)) != 0) ? 16 : 0)),
+            fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
+                FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
         const FLOAT_TYPE smin =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
-            fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
-               (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
+            fma(FLOAT_TYPE(by10.x)  + FLOAT_TYPE(by10.y)  + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+            fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+            fma(FLOAT_TYPE(by20.x)  + FLOAT_TYPE(by20.y)  + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+               (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+        temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
     }
 
+    tmp[gl_LocalInvocationID.x] = temp;
+
     // sum up partial sums and write back result
     barrier();
    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {

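The rewrite above accumulates into a per-thread register (temp) and touches shared memory only once before the tree reduction, instead of read-modify-writing tmp[] every iteration. A serial C++ stand-in for the final log2-step reduction (illustrative only; a real GPU runs the inner loop across threads with barriers):

    #include <cstdio>

    // Serial stand-in for the shader's shared-memory tree reduction:
    // each of 32 "threads" contributes one partial sum, then pairs are
    // folded s=16,8,4,2,1 exactly like the [[unroll]] loop in the shader.
    int main() {
        float tmp[32];
        for (int tid = 0; tid < 32; ++tid) {
            tmp[tid] = 1.0f + tid;              // stand-in for each thread's temp
        }
        for (int s = 16; s > 0; s >>= 1) {
            for (int tid = 0; tid < s; ++tid) { // "if (tid < s)" across the warp
                tmp[tid] += tmp[tid + s];
            }
            // a barrier() would sit here on the GPU
        }
        printf("sum = %f\n", tmp[0]);           // 1+2+...+32 = 528
        return 0;
    }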
@@ -1,5 +1,7 @@
 #version 450
 
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+
 #include "mul_mat_vec_base.comp"
 
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
 void main() {
     const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 
+    if (row >= p.stride_d) {
+        return;
+    }
+
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
@@ -36,35 +42,60 @@ void main() {
     const uint s_offset = 8*v_im + is;
     const uint y_offset = 128*v_im + l0;
 
-    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
-#if K_QUANTS_PER_ITERATION == 1
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +   0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx +  96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >>  4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
-                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + 112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >>  4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
-#else
+        FLOAT_TYPE scales[4];
+        scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
+        scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
+        scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
+        scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
+
+        uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
+        uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
+
+        uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
+        uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
+        uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
+        uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
+
+        uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
+        uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
+        uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
+        uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
+        uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
+
+        uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
+        uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
+        uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
+        uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
+
+        uvec4 q0 = uvec4(unpack8(q0_u32));
+        uvec4 q1 = uvec4(unpack8(q1_u32));
+        uvec4 q2 = uvec4(unpack8(q2_u32));
+        uvec4 q3 = uvec4(unpack8(q3_u32));
+
+        B_TYPE_VEC4 by0  = data_b_v4[(b_offset + y_idx) / 4];
+        B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
+        B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
+        B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
+
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
-            sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
-                  fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
+            sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
+                  fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
+                  fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
+                  fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
         }
-        tmp[16 * ix + tid] += sum;
-#endif
+        temp += sum * d;
     }
 
+    tmp[gl_LocalInvocationID.x] = temp;
+
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (uint s = 16; s > 0; s >>= 1) {

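Both the old and new q6_K paths rebuild each 6-bit quant from a 4-bit low part (ql) and a 2-bit high part (qh), then subtract 32 to re-center. A tiny C++ sketch of that reconstruction (dequant_q6 is an illustrative name, not a function in the patch):

    #include <cstdint>
    #include <cstdio>

    // Rebuild one signed q6_K quant the way the shader does: the low 4 bits
    // come from ql, the top 2 bits from qh, and subtracting 32 re-centers
    // the 0..63 range to -32..31.
    static int dequant_q6(uint8_t ql_nibble, uint8_t qh_2bits) {
        return int((ql_nibble & 0xF) | ((qh_2bits & 3) << 4)) - 32;
    }

    int main() {
        printf("%d %d %d\n",
               dequant_q6(0x0, 0),   // -32 (minimum)
               dequant_q6(0xF, 3),   //  31 (maximum)
               dequant_q6(0x0, 2));  //   0 (midpoint)
        return 0;
    }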
@@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
 #endif
 
 void main() {
+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
 #ifdef MUL_MAT_ID
     const uint expert_idx = gl_GlobalInvocationID.z;
 #else

@@ -1,6 +1,7 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_control_flow_attributes : enable
 
 layout (push_constant) uniform parameter
 {
@@ -11,14 +12,13 @@ layout (push_constant) uniform parameter
     float m0;
     float m1;
     uint n_head_log2;
+    uint nrows_x;
 } p;
 
 #include "types.comp"
 
-#extension GL_EXT_control_flow_attributes : enable
-#define BLOCK_SIZE 512
-
-layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
@@ -26,11 +26,18 @@ layout (binding = 2) buffer D {D_TYPE data_d[];};
 
 shared FLOAT_TYPE vals[BLOCK_SIZE];
 
-void main() {
+// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate
+// over all the columns. The main function tries to pass a constant here,
+// as if it were a template function, to allow unrolling.
+void soft_max(uint num_iters) {
     const uint tid = gl_LocalInvocationID.x;
     const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint rowy = rowx % p.KY;
 
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
     float slope = 1.0f;
 
     // ALiBi
@@ -46,19 +53,39 @@ void main() {
     // Find max
     FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
 
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+    // Cache values while we compute the max, so we don't need to read them
+    // again when we're ready to compute exp(x-max).
+    const uint DATA_CACHE_SIZE = 16;
+    FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
+
+    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
         const uint col = col0 + tid;
 
-        if (col >= p.KX) {
-            break;
+        FLOAT_TYPE a = FLOAT_TYPE(0);
+        if (col < p.KX) {
+            a = data_a[rowx * p.KX + col];
         }
 
-        max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
+        FLOAT_TYPE b = FLOAT_TYPE(0);
+        if (p.KY > 0 && col < p.KX) {
+            b = data_b[rowy * p.KX + col];
+        }
+
+        FLOAT_TYPE v = a * p.scale + slope * b;
+
+        if (col < p.KX) {
+            max_val = max(max_val, v);
+        }
+
+        if (idx < DATA_CACHE_SIZE) {
+            data_cache[idx] = v;
+        }
     }
 
+    // reduce across the workgroup
     vals[tid] = max_val;
 
     barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
             vals[tid] = max(vals[tid], vals[tid + s]);
         }
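The caching idea above is simple: while scanning for the row maximum, the first DATA_CACHE_SIZE scaled inputs per thread are kept in registers so the exp() pass does not re-read them from global memory. A serial C++ sketch of the pattern (data and scale values are made up for illustration):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Serial sketch of the shader's two-pass softmax with a small local cache:
    // pass 1 finds the max and caches the first few scaled values, pass 2
    // reuses cached values and only re-reads "global memory" on a cache miss.
    int main() {
        const size_t DATA_CACHE_SIZE = 16;
        std::vector<float> x = {0.5f, 2.0f, -1.0f, 3.5f, 0.0f}; // one "row"
        float cache[DATA_CACHE_SIZE];
        float max_val = -INFINITY;
        for (size_t i = 0; i < x.size(); ++i) {
            float v = x[i] * 0.125f;             // v = a * p.scale (+ slope * b)
            max_val = std::max(max_val, v);
            if (i < DATA_CACHE_SIZE) {
                cache[i] = v;                    // hit: reuse in the exp pass
            }
        }
        float sum = 0.0f;
        for (size_t i = 0; i < x.size(); ++i) {
            float v = (i < DATA_CACHE_SIZE) ? cache[i]
                                            : x[i] * 0.125f; // miss: re-read
            sum += std::exp(v - max_val);
        }
        printf("max=%f sum=%f\n", max_val, sum);
        return 0;
    }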
@@ -68,39 +95,80 @@ void main() {
     max_val = vals[0];
     barrier();
 
-    // Sum up values
-    vals[tid] = FLOAT_TYPE(0.0f);
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
 
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+    // Compute sum{exp(x - max)}
+    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
         const uint col = col0 + tid;
 
         if (col >= p.KX) {
             break;
         }
 
+        // compute exp(a*scale+b*slope), add it to sum, and cache the new value
+        // in data_cache if possible.
         const uint i = rowx * p.KX + col;
-        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
-        vals[tid] += val;
+        FLOAT_TYPE val;
+        if (idx < DATA_CACHE_SIZE) {
+            val = exp(data_cache[idx] - max_val);
+        } else {
+            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        }
+        sum += val;
+        if (idx < DATA_CACHE_SIZE) {
+            data_cache[idx] = val;
+        } else {
+            data_d[i] = D_TYPE(val);
+        }
     }
 
+    // reduce across the workgroup
+    vals[tid] = sum;
     barrier();
-    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
             vals[tid] += vals[tid + s];
         }
         barrier();
     }
+    sum = vals[0];
 
-    const D_TYPE divisor = D_TYPE(vals[0]);
+    FLOAT_TYPE rcpdivisor = 1.0/sum;
 
-    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+    [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
         const uint col = col0 + tid;
 
         if (col >= p.KX) {
-            break;
+            continue;
         }
 
-        data_d[rowx*p.KX + col] /= divisor;
+        if (idx < DATA_CACHE_SIZE) {
+            data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
+        } else {
+            data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
+        }
     }
 }
+
+void main() {
+    // instantiate the soft_max function for several different
+    // dimensions, to allow loop unrolling
+    uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    if (num_blocks > 32) {
+        soft_max(num_blocks);
+    } else if (num_blocks > 16) {
+        soft_max(32);
+    } else if (num_blocks > 8) {
+        soft_max(16);
+    } else if (num_blocks > 4) {
+        soft_max(8);
+    } else if (num_blocks == 4) {
+        soft_max(4);
+    } else if (num_blocks == 3) {
+        soft_max(3);
+    } else if (num_blocks == 2) {
+        soft_max(2);
+    } else if (num_blocks == 1) {
+        soft_max(1);
+    }
+}

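GLSL has no templates, so the dispatcher above passes literal iteration counts to soft_max() and relies on the compiler propagating the constant to unroll the loops. In C++ the same idea is expressible with a real template parameter; a minimal sketch (soft_max_rows is an illustrative name, not part of the patch):

    #include <cstdio>

    // The shader fakes templates by calling soft_max(32), soft_max(16), ...
    // with literal arguments so the compiler can fully unroll. In C++ the
    // trip count can be a genuine compile-time constant.
    template <unsigned NUM_ITERS>
    static unsigned soft_max_rows(unsigned kx, unsigned block_size) {
        unsigned work = 0;
        for (unsigned col0 = 0, idx = 0; idx < NUM_ITERS; col0 += block_size, ++idx) {
            if (col0 < kx) {
                ++work; // loop body with a compile-time trip count
            }
        }
        return work;
    }

    int main() {
        const unsigned BLOCK_SIZE = 32;
        unsigned kx = 100;
        unsigned num_blocks = (kx + BLOCK_SIZE - 1) / BLOCK_SIZE; // 4
        unsigned work = 0;
        if      (num_blocks > 8) work = soft_max_rows<16>(kx, BLOCK_SIZE);
        else if (num_blocks > 4) work = soft_max_rows<8>(kx, BLOCK_SIZE);
        else                     work = soft_max_rows<4>(kx, BLOCK_SIZE);
        printf("iterations doing work: %u\n", work);
        return 0;
    }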
@@ -1,6 +1,8 @@
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#endif
+#if !defined(GGML_TYPES_COMP)
+#define GGML_TYPES_COMP
+
+#extension GL_EXT_shader_explicit_arithmetic_types : require
 
 #if defined(DATA_A_F32)
 #define QUANT_K 1
@@ -38,8 +40,14 @@ struct block_q4_0
     float16_t d;
     uint8_t qs[16];
 };
+struct block_q4_0_packed16
+{
+    float16_t d;
+    uint16_t qs[16/2];
+};
 
 #define A_TYPE block_q4_0
+#define A_TYPE_PACKED16 block_q4_0_packed16
 #endif
 
 #if defined(DATA_A_Q4_1)

@@ -54,7 +62,15 @@ struct block_q4_1
     uint8_t qs[16];
 };
 
+struct block_q4_1_packed16
+{
+    float16_t d;
+    float16_t m;
+    uint16_t qs[16/2];
+};
+
 #define A_TYPE block_q4_1
+#define A_TYPE_PACKED16 block_q4_1_packed16
 #endif
 
 #if defined(DATA_A_Q5_0)

@@ -70,7 +86,15 @@ struct block_q5_0
     uint8_t qs[16];
 };
 
+struct block_q5_0_packed16
+{
+    float16_t d;
+    uint16_t qh[2];
+    uint16_t qs[16/2];
+};
+
 #define A_TYPE block_q5_0
+#define A_TYPE_PACKED16 block_q5_0_packed16
 #endif
 
 #if defined(DATA_A_Q5_1)

@@ -87,7 +111,16 @@ struct block_q5_1
     uint8_t qs[16];
 };
 
+struct block_q5_1_packed16
+{
+    float16_t d;
+    float16_t m;
+    uint qh;
+    uint16_t qs[16/2];
+};
+
 #define A_TYPE block_q5_1
+#define A_TYPE_PACKED16 block_q5_1_packed16
 #endif
 
 #if defined(DATA_A_Q8_0)

@@ -100,8 +133,14 @@ struct block_q8_0
     float16_t d;
     int8_t qs[32];
 };
+struct block_q8_0_packed16
+{
+    float16_t d;
+    uint16_t qs[32/2];
+};
 
 #define A_TYPE block_q8_0
+#define A_TYPE_PACKED16 block_q8_0_packed16
 #endif
 
 // K-quants

@@ -116,7 +155,23 @@ struct block_q2_K
     f16vec2 d;
 };
 
+struct block_q2_K_packed16
+{
+    uint16_t scales[QUANT_K/16/2];
+    uint16_t qs[QUANT_K/4/2];
+    f16vec2 d;
+};
+
+struct block_q2_K_packed32
+{
+    uint32_t scales[QUANT_K/16/4];
+    uint32_t qs[QUANT_K/4/4];
+    f16vec2 d;
+};
+
 #define A_TYPE block_q2_K
+#define A_TYPE_PACKED16 block_q2_K_packed16
+#define A_TYPE_PACKED32 block_q2_K_packed32
 #endif
 
 #if defined(DATA_A_Q3_K)

@@ -131,7 +186,16 @@ struct block_q3_K
     float16_t d;
 };
 
+struct block_q3_K_packed16
+{
+    uint16_t hmask[QUANT_K/8/2];
+    uint16_t qs[QUANT_K/4/2];
+    uint16_t scales[12/2];
+    float16_t d;
+};
+
 #define A_TYPE block_q3_K
+#define A_TYPE_PACKED16 block_q3_K_packed16
 #endif
 
 #if defined(DATA_A_Q4_K)

@@ -145,7 +209,23 @@ struct block_q4_K
     uint8_t qs[QUANT_K/2];
 };
 
+struct block_q4_K_packed16
+{
+    f16vec2 d;
+    uint16_t scales[3*QUANT_K/64/2];
+    uint16_t qs[QUANT_K/2/2];
+};
+
+struct block_q4_K_packed32
+{
+    f16vec2 d;
+    uint32_t scales[3*QUANT_K/64/4];
+    uint32_t qs[QUANT_K/2/4];
+};
+
 #define A_TYPE block_q4_K
+#define A_TYPE_PACKED16 block_q4_K_packed16
+#define A_TYPE_PACKED32 block_q4_K_packed32
 #endif
 
 #if defined(DATA_A_Q5_K)

@@ -160,7 +240,16 @@ struct block_q5_K
     uint8_t qs[QUANT_K/2];
 };
 
+struct block_q5_K_packed16
+{
+    f16vec2 d;
+    uint16_t scales[12/2];
+    uint16_t qh[QUANT_K/8/2];
+    uint16_t qs[QUANT_K/2/2];
+};
+
 #define A_TYPE block_q5_K
+#define A_TYPE_PACKED16 block_q5_K_packed16
 #endif
 
 #if defined(DATA_A_Q6_K)

@@ -175,7 +264,16 @@ struct block_q6_K
     float16_t d;
 };
 
+struct block_q6_K_packed16
+{
+    uint16_t ql[QUANT_K/2/2];
+    uint16_t qh[QUANT_K/4/2];
+    int8_t scales[QUANT_K/16];
+    float16_t d;
+};
+
 #define A_TYPE block_q6_K
+#define A_TYPE_PACKED16 block_q6_K_packed16
 #endif
 
 // IQuants

@@ -191,10 +289,30 @@ struct block_iq4_nl
     uint8_t qs[QUANT_K/2];
 };
 
-#define A_TYPE block_iq4_nl
+struct block_iq4_nl_packed16
+{
+    float16_t d;
+    uint16_t qs[QUANT_K/2/2];
+};
 
-const int8_t kvalues_iq4nl[16] = {
+#define A_TYPE block_iq4_nl
+#define A_TYPE_PACKED16 block_iq4_nl_packed16
+
+const int8_t kvalues_iq4nl_const[16] = {
     int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
     int8_t(1),    int8_t(13),   int8_t(25),  int8_t(38),  int8_t(53),  int8_t(69),  int8_t(89),  int8_t(113)
 };
+
+shared FLOAT_TYPE kvalues_iq4nl[16];
+
+void init_iq4nl_shmem()
+{
+    // copy the table into shared memory and sync
+    if (gl_LocalInvocationIndex.x < 16) {
+        kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
+    }
+    barrier();
+}
 #endif
 
+#endif // !defined(GGML_TYPES_COMP)

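The *_packed16/*_packed32 structs above declare the exact same block bytes with wider fields, so one buffer load fetches two or four quant bytes at once. A C++ analogue of the aliasing (plain uint16_t stands in for float16_t; little-endian layout assumed, matching typical GPU buffers):

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Same bytes, two views: byte-wise quants vs. 16-bit packed lanes.
    struct block_q8_0          { uint16_t d; int8_t   qs[32];   };
    struct block_q8_0_packed16 { uint16_t d; uint16_t qs[32/2]; };
    static_assert(sizeof(block_q8_0) == sizeof(block_q8_0_packed16),
                  "packed view must cover exactly the same bytes");

    int main() {
        block_q8_0 b = {};
        b.qs[4] = 7;
        b.qs[5] = -2;
        block_q8_0_packed16 p;
        std::memcpy(&p, &b, sizeof b);      // reinterpret as 16-bit lanes
        uint16_t w = p.qs[2];               // holds qs[4] and qs[5]
        // little-endian: low byte is qs[4], high byte is qs[5]
        printf("lo=%d hi=%d\n", int(int8_t(w & 0xFF)), int(int8_t(w >> 8)));
        return 0;
    }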
@@ -318,10 +318,10 @@ void process_shaders() {
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
 
-        string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
-        string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
+        string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
+        string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
 
-        string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
+        string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 
         // Dequant shaders
         if (tname != "f16") {

@@ -332,11 +332,11 @@ void process_shaders() {
         shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
 
         if (tname == "f16") {
-            string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
         } else {
-            string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
+            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
         }
-        string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
+        string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
     }
 }
 

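The get_rows variants now fold base_dict into their defines via merge_maps, the same pattern the mul_mat_vec variants already use. A sketch of such a helper in the spirit of the generator's merge_maps (the real implementation in vulkan-shaders-gen.cpp may differ in detail; override entries win over base defines):

    #include <map>
    #include <string>
    #include <cstdio>

    // Merge two define maps: start from the base, let extras override.
    static std::map<std::string, std::string> merge_maps(
            const std::map<std::string, std::string> & base,
            const std::map<std::string, std::string> & extra) {
        std::map<std::string, std::string> out = base;
        for (const auto & kv : extra) {
            out[kv.first] = kv.second;
        }
        return out;
    }

    int main() {
        std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
        auto defs = merge_maps(base_dict, {{"B_TYPE", "int"}, {"D_TYPE", "float"}});
        for (const auto & kv : defs) {
            printf("-D%s=%s\n", kv.first.c_str(), kv.second.c_str());
        }
        return 0;
    }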
@@ -5020,8 +5020,10 @@ static void ggml_hash_map_free(struct hash_map * map) {
 }
 
 // utility functions to change gradients
-// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
-// else if a is in zero_table, replace a
+// isrc is the index of tensor in cgraph->visited_has_set.keys
+// the corresponding gradient (accumulators) are also at position isrc
+// if tensor has a gradient accumulator, modify that accumulator in-place
+// else if there is no gradient for tensor, set the corresponding value
 // else, just add/subtract/etc. the gradients
 
 static void ggml_add_or_set(

@@ -5029,11 +5031,14 @@ static void ggml_add_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
-        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
     } else {
        cgraph->grads[isrc] = tensor;
    }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }

@@ -5041,18 +5046,20 @@ static void ggml_acc_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor,
         const size_t          nb1,
         const size_t          nb2,
         const size_t          nb3,
         const size_t          offset) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
     } else {
         struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }

@@ -5060,13 +5067,15 @@ static void ggml_add1_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }

@@ -5075,11 +5084,14 @@ static void ggml_sub_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_neg(ctx, tensor);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }

@@ -5096,12 +5108,12 @@ static void ggml_compute_backward(
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
     struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
-    const size_t isrc0 = ggml_hash_find(hash_set, src0);
-    const size_t isrc1 = ggml_hash_find(hash_set, src1);
-    const size_t isrc2 = ggml_hash_find(hash_set, src2);
-    const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
-    const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
-    const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
+    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
+    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
+    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
+    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
 
     switch (tensor->op) {
         case GGML_OP_DUP: {

@@ -5201,7 +5213,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SUM: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
+                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
             }
         } break;
         case GGML_OP_SUM_ROWS: {

@@ -5211,7 +5223,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
             }
         } break;
         case GGML_OP_REPEAT: {

@@ -5364,7 +5376,7 @@ static void ggml_compute_backward(
                     nb3 = (nb3 / n0) * ng;
                 }
 
-                ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
+                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
             }
         } break;
         case GGML_OP_PERMUTE: {

@@ -5598,10 +5610,9 @@ void ggml_build_backward_expand(
 
     const int n_nodes_f = cgraph->n_nodes;
 
-    const size_t hash_size = ggml_hash_size(2*cgraph->size);
-    memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
-    memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
-    bool * grads_needed = calloc(hash_size, sizeof(bool));
+    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
 
     {
         bool any_params = false;

@@ -5622,7 +5633,7 @@ void ggml_build_backward_expand(
             continue;
         }
 
-        bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
             // gradients in node->src[0] for one reason or another have no effect on output gradients

@@ -5666,9 +5677,12 @@ void ggml_build_backward_expand(
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
         const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
         if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grad_accs[igrad] = cgraph->grads[igrad];
+            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
+            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
+            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
         }
         grads_needed[igrad] = true;
     }

@@ -5766,10 +5780,10 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.n_nodes          =*/ i1 - i0,
         /*.n_leafs          =*/ 0,
         /*.nodes            =*/ cgraph0->nodes + i0,
-        /*.grads            =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
-        /*.grad_accs        =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
+        /*.grads            =*/ NULL, // gradients would need visited_hash_set
+        /*.grad_accs        =*/ NULL,
         /*.leafs            =*/ NULL,
-        /*.hash_table       =*/ { 0, NULL, NULL },
+        /*.visited_hash_set =*/ { 0, NULL, NULL },
         /*.order            =*/ cgraph0->order,
     };

@@ -5800,12 +5814,22 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
         }
     }
 
+    if (dst->grads) {
+        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    }
     if (src->grads) {
+        GGML_ASSERT(dst->grads     != NULL);
+        GGML_ASSERT(dst->grad_accs != NULL);
         for (int i = 0; i < src->n_nodes; ++i) {
             const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
             const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
 
+            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
 
             dst->grads[igrad_dst]     = src->grads[igrad_src];
             dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
         }

@@ -5840,13 +5864,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
         if (node->op == GGML_OP_OPT_STEP_ADAMW) {
             // clear momenta
-            if (node->src[2]->data) {
                 ggml_set_zero(node->src[2]);
-            }
-            if (node->src[3]->data) {
                 ggml_set_zero(node->src[3]);
-            }
         }
 
         // initial gradients of loss should be 1, 0 otherwise
         if (grad_acc) {

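The gradient rework above keys grads and grad_accs by a tensor's slot in the visited hash set rather than carrying a separate src pointer around. A toy C++ model of the bookkeeping, under the assumption that the set gives each tensor a stable slot index (all names here are illustrative, not the ggml API):

    #include <cstddef>
    #include <cstdio>

    // Toy model: tensors live in a visited set, and grads is a parallel
    // array indexed by the tensor's slot; a tensor with no entry simply
    // has a NULL gradient.
    struct Tensor { const char * name; };

    static const size_t SET_SIZE = 8;
    static Tensor * keys[SET_SIZE]  = {};
    static Tensor * grads[SET_SIZE] = {};

    static size_t slot_find(Tensor * t) {   // stand-in for ggml_hash_find
        for (size_t i = 0; i < SET_SIZE; ++i) {
            if (keys[i] == t) return i;
        }
        return (size_t) -1;
    }

    static void add_or_set(size_t isrc, Tensor * g) {
        if (grads[isrc]) {
            // would accumulate into the existing gradient here
            printf("accumulate into existing grad at slot %zu\n", isrc);
        } else {
            grads[isrc] = g;                // first gradient for this tensor
        }
    }

    int main() {
        Tensor a  = {"a"};
        Tensor ga = {"grad for a"};
        keys[3] = &a;
        size_t isrc = slot_find(&a);
        if (isrc != (size_t) -1) {
            add_or_set(isrc, &ga);
            printf("%s stored at slot %zu\n", grads[isrc]->name, isrc);
        }
        return 0;
    }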
@@ -243,6 +243,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R    = auto()
     DBRX         = auto()
     OLMO         = auto()
+    OLMO_1124    = auto()
     OLMOE        = auto()
     OPENELM      = auto()
     ARCTIC       = auto()

@@ -404,6 +405,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.COMMAND_R:      "command-r",
     MODEL_ARCH.DBRX:           "dbrx",
     MODEL_ARCH.OLMO:           "olmo",
+    MODEL_ARCH.OLMO_1124:      "olmo_1124",
     MODEL_ARCH.OLMOE:          "olmoe",
     MODEL_ARCH.OPENELM:        "openelm",
     MODEL_ARCH.ARCTIC:         "arctic",

@@ -1069,6 +1071,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OLMO_1124: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.OLMOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,

@@ -13,7 +13,7 @@ class TensorNameMap:
             "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
-            "model.embed_tokens",                        # llama-hf nemotron olmoe
+            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo_1124
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
             "language_model.embedding.word_embeddings", # persimmon

@@ -54,7 +54,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo_1124
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2

@@ -66,7 +66,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",  # gptneox
             "transformer.ln_f",           # gpt2 gpt-j falcon jais exaone
-            "model.norm",                 # llama-hf baichuan internlm2 olmoe
+            "model.norm",                 # llama-hf baichuan internlm2 olmoe olmo_1124
             "norm",                       # llama-pth
             "transformer.norm_f",         # mpt dbrx
             "ln_f",                       # refact bloom qwen gpt2

@@ -145,7 +145,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wq",                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",  # bert
             "transformer.h.{bid}.attn.q_proj",           # gpt-j

@@ -157,7 +157,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",     # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wk",               # llama-pth
             "encoder.layer.{bid}.attention.self.key",  # bert
             "transformer.h.{bid}.attn.k_proj",         # gpt-j

@@ -170,7 +170,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wv",                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",  # bert
             "transformer.h.{bid}.attn.v_proj",           # gpt-j

@@ -188,7 +188,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj",      # mpt
             "transformer.h.{bid}.self_attention.dense",    # falcon
             "h.{bid}.self_attention.dense",                # bloom
-            "model.layers.{bid}.self_attn.o_proj",         # llama-hf nemotron olmoe
+            "model.layers.{bid}.self_attn.o_proj",         # llama-hf nemotron olmoe olmo_1124
             "layers.{bid}.attention.wo",                   # llama-pth
             "encoder.layer.{bid}.attention.output.dense",  # bert
             "transformer.h.{bid}.attn.out_proj",           # gpt-j

@@ -215,7 +215,7 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.ATTN_POST_NORM: (
-            "model.layers.{bid}.post_attention_layernorm",  # gemma2
+            "model.layers.{bid}.post_attention_layernorm",  # gemma2 olmo_1124
         ),
 
         # Rotary embeddings

@@ -250,7 +250,7 @@ class TensorNameMap:
 
         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
-            "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+            "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo_1124
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (

@@ -273,7 +273,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.up_proj",    # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",   # falcon
             "h.{bid}.mlp.dense_h_to_4h",               # bloom
-            "model.layers.{bid}.mlp.up_proj",          # llama-hf refact nemotron
+            "model.layers.{bid}.mlp.up_proj",          # llama-hf refact nemotron olmo_1124
             "layers.{bid}.feed_forward.w3",            # llama-pth
             "encoder.layer.{bid}.intermediate.dense",  # bert
             "transformer.h.{bid}.mlp.fc_in",           # gpt-j

@@ -314,7 +314,7 @@ class TensorNameMap:
 
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact
+            "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact olmo_1124
             "layers.{bid}.feed_forward.w1",      # llama-pth
             "transformer.h.{bid}.mlp.w2",        # qwen
             "transformer.h.{bid}.mlp.c_fc2",     # jais

@@ -346,7 +346,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.down_proj",  # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h",   # falcon
             "h.{bid}.mlp.dense_4h_to_h",               # bloom
-            "model.layers.{bid}.mlp.down_proj",        # llama-hf nemotron
+            "model.layers.{bid}.mlp.down_proj",        # llama-hf nemotron olmo_1124
             "layers.{bid}.feed_forward.w2",            # llama-pth
             "encoder.layer.{bid}.output.dense",        # bert
             "transformer.h.{bid}.mlp.fc_out",          # gpt-j

@@ -383,7 +383,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm",         # persimmon
-            "model.layers.{bid}.self_attn.q_norm",              # cohere olmoe chameleon
+            "model.layers.{bid}.self_attn.q_norm",              # cohere olmoe chameleon olmo_1124
             "transformer.blocks.{bid}.attn.q_ln",               # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",  # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",             # openelm

@@ -392,7 +392,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",         # persimmon
-            "model.layers.{bid}.self_attn.k_norm",              # cohere olmoe chameleon
+            "model.layers.{bid}.self_attn.k_norm",              # cohere olmoe chameleon olmo_1124
             "transformer.blocks.{bid}.attn.k_ln",               # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",  # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",             # openelm

@@ -669,6 +669,9 @@ extern "C" {
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
     LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
 
+    // Check if the context supports KV cache shifting
+    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
+
     //
     // State / sessions
     //

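A compile-only C++ usage sketch for the new query; only llama_kv_cache_can_shift and llama_kv_cache_update are taken from the header above, everything else (the helper name, the fallback) is illustrative:

    #include "llama.h"

    // Guard K-shift based context reuse on backends that cannot shift.
    static void try_shift(struct llama_context * ctx) {
        if (llama_kv_cache_can_shift(ctx)) {
            llama_kv_cache_update(ctx);   // safe to rely on K-shifts
        } else {
            // fall back to re-evaluating the prompt instead of shifting
        }
    }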
208
src/llama.cpp
@ -193,6 +193,7 @@ enum llm_arch {
|
|||
LLM_ARCH_COMMAND_R,
|
||||
LLM_ARCH_DBRX,
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_OLMO_1124,
|
||||
LLM_ARCH_OLMOE,
|
||||
LLM_ARCH_OPENELM,
|
||||
LLM_ARCH_ARCTIC,
|
||||
|
@ -246,6 +247,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_COMMAND_R, "command-r" },
|
||||
{ LLM_ARCH_DBRX, "dbrx" },
|
||||
{ LLM_ARCH_OLMO, "olmo" },
|
||||
{ LLM_ARCH_OLMO_1124, "olmo_1124" },
|
||||
{ LLM_ARCH_OLMOE, "olmoe" },
|
||||
{ LLM_ARCH_OPENELM, "openelm" },
|
||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||
|
@ -1221,6 +1223,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_OLMO_1124,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_OLMOE,
|
||||
{
|
||||
|
@ -3481,21 +3502,13 @@ static bool llama_kv_cache_init(
|
|||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||
|
||||
const llama_model::buft_list_t * buft_list;
|
||||
ggml_backend_buffer_type_t buft;
|
||||
if (offload) {
|
||||
buft_list = model.dev_layer.at(i).buft_list;
|
||||
auto * dev = model.dev_layer.at(i).dev;
|
||||
buft = ggml_backend_dev_buffer_type(dev);
|
||||
} else {
|
||||
buft_list = &model.cpu_buft_list;
|
||||
buft = ggml_backend_cpu_buffer_type();
|
||||
}
|
||||
ggml_backend_buffer_type_t buft = select_buft(*buft_list,
|
||||
[&](ggml_context * ctx) {
|
||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
||||
if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
|
||||
return k;
|
||||
}
|
||||
ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
|
||||
return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
|
||||
});
|
||||
ggml_context * ctx = ctx_for_buft(buft);
|
||||
|
||||
if (!ctx) {
|
||||
|
@ -5917,6 +5930,17 @@ static void llm_load_hparams(
|
|||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMO_1124:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 16: model.type = e_model::MODEL_1B; break;
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 40: model.type = e_model::MODEL_13B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
@ -8698,6 +8722,31 @@ static bool llm_load_tensors(
|
|||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMO_1124:
|
||||
{
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
@@ -14595,6 +14644,130 @@ struct llm_build_context {
        return gf;
    }

    struct ggml_cgraph * build_olmo_1124() {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

        // mutable variable, needed during the last layer of the computation to skip unused tokens
        int32_t n_tokens = this->n_tokens;

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
        GGML_ASSERT(n_embd_head == hparams.n_rot);

        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

        // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = build_inp_pos();

        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

        for (int il = 0; il < n_layer; ++il) {
            struct ggml_tensor * inpSA = inpL;

            cur = inpL;

            // self_attention
            {
                // compute Q and K and RoPE them
                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);

                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                cb(Kcur, "Kcur", il);

                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                cb(Vcur, "Vcur", il);

                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(Qcur, "Qcur_normed", il);

                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(Kcur, "Kcur_normed", il);

                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);

                Qcur = ggml_rope_ext(
                    ctx0, Qcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Qcur, "Qcur_rope", il);

                Kcur = ggml_rope_ext(
                    ctx0, Kcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Kcur, "Kcur_rope", il);

                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                        model.layers[il].wo, NULL,
                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
            }

            cur = llm_build_norm(ctx0, cur, hparams,
                    model.layers[il].attn_post_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "attn_post_norm", il);

            if (il == n_layer - 1) {
                // skip computing output for unused tokens
                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
                n_tokens = n_outputs;
                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
            cur = llm_build_ffn(ctx0, lctx, ffn_inp,
                    model.layers[il].ffn_up,   NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
            cb(cur, "ffn_out", il);

            cur = llm_build_norm(ctx0, cur, hparams,
                    model.layers[il].ffn_post_norm, NULL,
                    LLM_NORM_RMS, cb, -1);
            cb(cur, "ffn_post_norm", -1);

            cur = ggml_add(ctx0, cur, ffn_inp);
            cb(cur, "ffn_out", il);

            cur = lctx.cvec.apply_to(ctx0, cur, il);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        cur = llm_build_norm(ctx0, cur, hparams,
                model.output_norm, NULL,
                LLM_NORM_RMS, cb, -1);
        cb(cur, "result_norm", -1);

        // lm_head
        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
        cb(cur, "result_output", -1);

        ggml_build_forward_expand(gf, cur);

        return gf;
    }
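Worth calling out in build_olmo_1124() above: the RMSNorms sit on the sublayer outputs (attn_post_norm, ffn_post_norm) before each residual add, rather than on the sublayer inputs as in the pre-norm builders elsewhere in this file, and Q/K get their own RMSNorm before RoPE. A pseudo-C++ schematic of one layer as wired above (Tensor and the lowercase helpers are stand-ins for the llm_build_* calls, not real llama.cpp symbols):

    // Pseudo-code schematic of one OLMO_1124 layer; Tensor and the helpers
    // are placeholders, not actual llama.cpp APIs.
    Tensor olmo2_block(Tensor x) {
        Tensor a = attn(rope(rms_norm(q(x))), rope(rms_norm(k(x))), v(x));
        Tensor h = x + rms_norm(a);   // attn_post_norm before the residual add
        Tensor f = ffn_swiglu(h);     // SiLU-gated FFN, parallel up/gate
        return h + rms_norm(f);       // ffn_post_norm before the residual add
    }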

    // based on the build_qwen2moe() function, changes:
    // * removed shared experts
    // * removed bias

@@ -16787,6 +16960,10 @@ static struct ggml_cgraph * llama_build_graph(
            {
                result = llm.build_olmo();
            } break;
        case LLM_ARCH_OLMO_1124:
            {
                result = llm.build_olmo_1124();
            } break;
        case LLM_ARCH_OLMOE:
            {
                result = llm.build_olmoe();
@@ -18199,7 +18376,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {

     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+        if (!llama_kv_cache_can_shift(&lctx)) {
             GGML_ABORT("Deepseek2 does not support K-shift");
         }

@@ -20060,6 +20237,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_QWEN:
        case LLM_ARCH_QWEN2:
        case LLM_ARCH_QWEN2MOE:
        case LLM_ARCH_OLMO_1124:
        case LLM_ARCH_OLMOE:
        case LLM_ARCH_PHI2:
        case LLM_ARCH_PHI3:
@@ -20451,6 +20629,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
    llama_kv_cache_update_internal(*ctx);
}

bool llama_kv_cache_can_shift(struct llama_context * ctx) {
    return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
}

// deprecated
size_t llama_get_state_size(struct llama_context * ctx) {
    return llama_state_get_size(ctx);
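The new llama_kv_cache_can_shift() centralizes the capability check that the K-shift hunk above now consults, and exposes it to API consumers. A hedged caller-side sketch (the position arithmetic is illustrative; llama_kv_cache_seq_add is the existing public shift API):

    #include "llama.h"

    // Sketch: guard a context shift behind the new capability check.
    // n_keep/n_past/n_discard are illustrative, not taken from this commit.
    static void shift_context(llama_context * ctx, llama_pos n_keep, llama_pos n_past, llama_pos n_discard) {
        if (!llama_kv_cache_can_shift(ctx)) {
            return; // e.g. Deepseek2/MLA: K-shift unsupported
        }
        // slide positions [n_keep + n_discard, n_past) of sequence 0 back by n_discard
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);
        llama_kv_cache_update(ctx); // applies the deferred K-shift
    }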