Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.devops/full-cuda.Dockerfile
#	.devops/llama-cli-cann.Dockerfile
#	.devops/llama-cli-cuda.Dockerfile
#	.devops/llama-cli-intel.Dockerfile
#	.devops/llama-cli-musa.Dockerfile
#	.devops/llama-cli-vulkan.Dockerfile
#	.devops/llama-server-cuda.Dockerfile
#	.devops/llama-server-intel.Dockerfile
#	.devops/llama-server-musa.Dockerfile
#	.devops/llama-server-vulkan.Dockerfile
#	.gitignore
#	CMakeLists.txt
#	Makefile
#	cmake/llama-config.cmake.in
#	docs/backend/SYCL.md
#	docs/build.md
#	examples/llama-bench/llama-bench.cpp
#	flake.lock
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-backend.cpp
#	ggml/src/ggml-blas/CMakeLists.txt
#	ggml/src/ggml-cpu/CMakeLists.txt
#	ggml/src/ggml-cpu/ggml-cpu.c
#	ggml/src/ggml-cuda/CMakeLists.txt
#	ggml/src/ggml-hip/CMakeLists.txt
#	ggml/src/ggml-metal/CMakeLists.txt
#	ggml/src/ggml-musa/CMakeLists.txt
#	ggml/src/ggml-sycl/CMakeLists.txt
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
This commit is contained in:
Concedo 2024-11-21 16:26:24 +08:00
commit 091a432cf6
38 changed files with 167475 additions and 136736 deletions

161
.clang-format Normal file
View file

@ -0,0 +1,161 @@
---
Language: Cpp
AlignAfterOpenBracket: Align
AlignArrayOfStructures: Left
AlignConsecutiveAssignments: AcrossComments
AlignConsecutiveBitFields: AcrossComments
AlignConsecutiveDeclarations: AcrossComments
AlignConsecutiveMacros: AcrossComments
# AlignConsecutiveShortCaseStatements: AcrossComments
AlignEscapedNewlines: Left # LeftWithLastLine
AlignOperands: Align
AlignTrailingComments:
Kind: Always
OverEmptyLines: 1
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: false
# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: true
BinPackParameters: true # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach
BraceWrapping:
AfterCaseLabel: true
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: false
BeforeLambdaBody: false
BeforeWhile: false
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
# BreakAdjacentStringLiterals: true
BreakAfterAttributes: Never
BreakBeforeBinaryOperators: None
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: false
# BreakBinaryOperations: Never
BreakConstructorInitializers: AfterColon
# BreakFunctionDefinitionParameters: false
BreakInheritanceList: AfterComma
BreakStringLiterals: true
# BreakTemplateDeclarations: Yes
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false
DisableFormat: false
EmptyLineBeforeAccessModifier: Leave
EmptyLineAfterAccessModifier: Never
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
IncludeBlocks: Regroup
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
SortPriority: 0
- Regex: '^<.*'
Priority: 2
SortPriority: 0
- Regex: '.*'
Priority: 3
SortPriority: 0
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: true
IndentCaseLabels: true
IndentExternBlock: NoIndent
IndentGotoLabels: false
IndentPPDirectives: AfterHash
IndentWidth: 4
IndentWrappedFunctionNames: false
InsertBraces: true # NOTE: may lead to incorrect formatting
InsertNewlineAtEOF: true
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
LambdaBodyIndentation: Signature
LineEnding: LF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PPIndentWidth: -1
PackConstructorInitializers: CurrentLine
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Middle
QualifierAlignment: Left
#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
RawStringFormats:
- Language: Cpp
Delimiters:
- cc
- CC
- cpp
- Cpp
- CPP
- 'c++'
- 'C++'
CanonicalDelimiter: ''
ReferenceAlignment: Middle
ReflowComments: false # IndentOnly
SeparateDefinitionBlocks: Always
SortIncludes: CaseInsensitive
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyBlock: false
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: Never
SpacesInContainerLiterals: true
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: -1
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpaceBeforeSquareBrackets: false
Standard: c++17
TabWidth: 4
UseTab: Never
WhitespaceSensitiveMacros: ['STRINGIZE']
...

View file

@ -1,26 +0,0 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc3.1.0
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
RUN apt-get update && \
apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1
COPY requirements.txt requirements.txt
COPY requirements requirements
RUN pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt
WORKDIR /app
COPY . .
RUN cmake -B build -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
cmake --build build --config Release -j$(nproc) && \
cp build/bin/* .
ENTRYPOINT ["/app/.devops/tools.sh"]

View file

@ -877,6 +877,12 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
llama_free_model(model);
return iparams;
}
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);

View file

@ -3040,6 +3040,11 @@ class OlmoModel(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Olmo1124ForCausalLM")
class Olmo1124Model(Model):
model_arch = gguf.MODEL_ARCH.OLMO_1124
@Model.register("OlmoeForCausalLM")
class OlmoeModel(Model):
model_arch = gguf.MODEL_ARCH.OLMOE

View file

@ -252,6 +252,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
}
void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (size == 0) {
@ -266,6 +267,7 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
}
void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (size == 0) {
@ -689,7 +691,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
}
static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
ggml_backend_buffer_t buffer = tensor->buffer;
ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (buffer == NULL) {
return -1;
}
@ -724,8 +726,6 @@ static bool backend_prealloc_warn = false;
// returns the backend that should be used for the node based on the current locations
static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
// TODO: use supports_op to check if the backend supports the op
// assign pre-allocated nodes to their backend
int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
if (cur_backend_id != -1) {
@ -747,7 +747,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
if(!backend_prealloc_warn)
{
backend_prealloc_warn = true;
printf("\nCaution: pre-allocated tensor in a backend that cannot run the operation\n");
printf("\nCaution: pre-allocated tensor (%s) in a backend that cannot run the operation\n", tensor->name);
}
}

View file

@ -2372,7 +2372,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
// figure out which node we're on
uint current_cpu;
int getcpu_ret = 0;
// #if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
//#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)
// getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
//#else
// // old glibc doesn't have a wrapper for this call. Fall back on direct syscall

View file

@ -1,699 +0,0 @@
#include "dmmv.cuh"
#include "dequantize.cuh"
#include "convert.cuh"
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q2_K * x = (const block_q2_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = a[1] & 0x0f0f0f0f;
aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
}
tmp += dall * sum1 - dmin * sum2;
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q3_K * x = (const block_q3_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7
const uint8_t m = 1 << (4*im);
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;
uint16_t utmp[4];
const int8_t * s = (const int8_t *)utmp;
const uint16_t s_shift = 4*im;
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const uint8_t * h = x[i].hmask + l0;
const uint16_t * a = (const uint16_t *)x[i].scales;
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
const float d = x[i].d;
float sum = 0;
for (int l = 0; l < n; ++l) {
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
}
tmp += d * sum;
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q4_K * x = (const block_q4_K *)vx + ib0;
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
const uint32_t * q2 = q1 + 16;
q32[0] = q1[0] & 0x0f0f0f0f;
q32[1] = q1[0] & 0xf0f0f0f0;
q32[2] = q2[0] & 0x0f0f0f0f;
q32[3] = q2[0] & 0xf0f0f0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 4; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q5_K * x = (const block_q5_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 2;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
const uint8_t hm1 = 1 << (2*im);
const uint8_t hm2 = hm1 << 4;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
uint16_t q16[8];
const uint8_t * q4 = (const uint8_t *)q16;
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
const uint16_t * q1 = (const uint16_t *)ql1;
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[8] & 0x0f0f;
q16[2] = (q1[0] >> 4) & 0x0f0f;
q16[3] = (q1[8] >> 4) & 0x0f0f;
q16[4] = q2[0] & 0x0f0f;
q16[5] = q2[8] & 0x0f0f;
q16[6] = (q2[0] >> 4) & 0x0f0f;
q16[7] = (q2[8] >> 4) & 0x0f0f;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
const uint8_t * qh = x[i].qh + qh_offset;
const int8_t * s = x[i].scales + s_offset;
const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
}
}
static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;
// load 2 halfs into register in a single instruction
const half2 x_reg = *((half2 *) &(x[ib + iqs]));
// automatic half -> float type cast if dfloat == float
v.x = __low2float(x_reg);
v.y = __high2float(x_reg);
}
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
type == GGML_TYPE_F16 ? convert_f16 :
nullptr;
}
template <ggml_type type>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int tid = threadIdx.x;
const int iter_stride = 2*GGML_CUDA_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;
// partial sum for each thread
#ifdef GGML_CUDA_F16
half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
float tmp = 0.0f;
#endif // GGML_CUDA_F16
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
// process 2 vals per j iter
// dequantize
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
dfloat2 v;
dequantize_kernel(vx, ib, iqs + j/qr, v);
// matrix multiplication
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_F16
if ( y_offset == 1 ) {
// load 2 dfloats into register in a single instruction
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
tmp += __hmul2(v, y_reg);
}
else {
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
});
}
#else
if ( y_offset == 1 ) {
// load 2 dfloats into register in a single instruction
const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
tmp += v.x * y_reg.x;
tmp += v.y * y_reg.y;
}
else {
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
}
#endif // GGML_CUDA_F16
}
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
#ifdef GGML_CUDA_F16
dst[row] = tmp.x + tmp.y;
#else
dst[row] = tmp;
#endif // GGML_CUDA_F16
}
}
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1);
dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<GGML_TYPE_F16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
void ggml_cuda_op_dequantize_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_UNUSED(ctx);
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
src1_dfloat = src1_dfloat_a.alloc(ne00);
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
}
#else
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_F16
switch (src0->type) {
case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_0:
dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_1:
dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_K:
dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_K:
dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}
bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
src0_type == GGML_TYPE_F16;
}

View file

@ -18,11 +18,11 @@ bool g_mul_mat_q = false;
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
#include "ggml-cuda/mmv.cuh"
#include "ggml-cuda/mmvq.cuh"
#include "ggml-cuda/norm.cuh"
#include "ggml-cuda/opt-step-adamw.cuh"
@ -1021,114 +1021,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
#define MUL_MAT_SRC1_COL_STRIDE 128
static __global__ void mul_mat_p021_f16_f32(
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int channel_x = channel / (nchannels_y / nchannels_x);
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
float tmp = 0.0f;
for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
break;
}
// x is transposed and permuted
const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
// y is not transposed but permuted
const int iy = channel*nrows_y + row_y;
tmp += xi * y[iy];
}
// dst is not transposed and not permuted
const int idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[idst] = tmp;
}
}
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int channel_x = channel / channel_x_divisor;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
const int idst = channel*nrows_dst + row_dst;
float tmp = 0.0f;
for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
break;
}
const int row_y = col_x;
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const int iy = channel*nrows_y + row_y;
const float xi = __half2float(x[ix]);
tmp += xi * y[iy];
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[idst] = tmp;
}
}
static void ggml_mul_mat_p021_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
}
static void ggml_mul_mat_vec_nc_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
(vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
}
static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
@ -1655,58 +1547,6 @@ static void ggml_cuda_op_mul_mat(
}
}
static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne12 = src1->ne[2];
cudaStream_t main_stream = ctx.stream();
void * src0_ddq = src0->data;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t ne12 = src1->ne[2];
cudaStream_t main_stream = ctx.stream();
void * src0_ddq = src0->data;
float * src1_ddf = (float *) src1->data;
float * dst_ddf = (float *) dst->data;
const int64_t row_stride_x = nb01 / sizeof(half);
const int64_t channel_stride_x = nb02 / sizeof(half);
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}
static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
@ -1880,21 +1720,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
// if mmvq is available it's a better choice than dmmv:
#ifndef GGML_CUDA_FORCE_DMMV
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
#endif // GGML_CUDA_FORCE_DMMV
bool any_gpus_with_slow_fp16 = false;
bool any_gpus_without_fp16_mma = false;
if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@ -1908,11 +1744,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
}
// debug helpers
@ -1923,18 +1761,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// FP32 precision KQ single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// FP32 precision KQV single-batch for batch size 1 without FlashAttention
ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// KQ + KQV multi-batch without FlashAttention
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
} else if (use_mul_mat_q) {

223
ggml/src/ggml-cuda/mmv.cu Normal file
View file

@ -0,0 +1,223 @@
#include "common.cuh"
#include "mmv.cuh"
template <typename type_acc, int block_size>
static __global__ void mul_mat_vec(
const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += channel *stride_channel_y;
dst += channel *stride_channel_dst;
const half2 * x2 = (const half2 *) x;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;
if (block_size > WARP_SIZE) {
if (tid < WARP_SIZE) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
}
float sumf;
if (std::is_same<type_acc, float>::value) {
sumf = 0.0f;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
const float2 tmpy = y2[col2];
sumf += tmpx.x * tmpy.x;
sumf += tmpx.y * tmpy.y;
}
} else {
#ifdef FP16_AVAILABLE
half2 sumh2 = make_half2(0.0f, 0.0f);
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmp = y2[col2];
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
}
sumf = __low2float(sumh2) + __high2float(sumh2);
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
sumf = warp_reduce_sum(sumf);
if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf;
__syncthreads();
if (tid > WARP_SIZE) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum(sumf);
}
if (tid != 0) {
return;
}
dst[row] = sumf;
}
template <typename type_acc>
static void launch_mul_mat_vec_cuda(
const half * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x;
int64_t block_size_best = WARP_SIZE;
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
block_size_best = block_size;
}
}
const int smem = WARP_SIZE*sizeof(float);
const dim3 block_nums(nrows, 1, nchannels_y);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 64: {
mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 96: {
mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 128: {
mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 160: {
mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 192: {
mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 224: {
mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 256: {
mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
static void mul_mat_vec_cuda(
const half * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
enum ggml_prec prec, cudaStream_t stream) {
switch (prec) {
case GGML_PREC_DEFAULT: {
launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
} break;
case GGML_PREC_F32: {
launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
} break;
}
}
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
GGML_ASSERT(src1->ne[1] == 1);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
const half * src0_d = (const half *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
const int64_t ne02 = src0->ne[2];
const int64_t ne12 = src1->ne[2];
GGML_ASSERT(dst->ne[2] == ne12);
GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(src1->ne[3] == 1);
GGML_ASSERT( dst->ne[3] == 1);
const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
}
void ggml_cuda_op_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1_ncols == 1);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t channel_stride_x = 0;
const int64_t channel_stride_y = 0;
const int64_t channel_stride_dst = 0;
mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
GGML_UNUSED(ctx);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}

View file

@ -1,20 +1,12 @@
#include "common.cuh"
// dmmv = dequantize_mul_mat_vec
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
#define MMV_MAX_ROWS 512
// TODO: remove this?
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
#ifndef GGML_CUDA_MMV_Y
#define GGML_CUDA_MMV_Y 1
#endif
void ggml_cuda_op_dequantize_mul_mat_vec(
void ggml_cuda_op_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);

View file

@ -295,6 +295,9 @@ struct ggml_cgraph {
enum ggml_cgraph_eval_order order;
};
// returns a slice of cgraph with nodes [i0, i1)
// the slice does not have leafs or gradients
// if you need the gradients, get them from the original graph
struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
// Memory allocation

View file

@ -0,0 +1,249 @@
#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
// however, be careful from int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
} ggml_metal_kargs_bin;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_repeat;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_cpy;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t n_past;
int32_t n_dims;
int32_t n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mv;
typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
} ggml_metal_kargs_mul_mm_id;
typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_norm;
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_rms_norm;
#endif // GGML_METAL_IMPL

View file

@ -2,6 +2,7 @@
#import "ggml-impl.h"
#import "ggml-backend-impl.h"
#import "ggml-metal-impl.h"
#import <Foundation/Foundation.h>
@ -125,6 +126,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
GGML_METAL_KERNEL_TYPE_SILU,
GGML_METAL_KERNEL_TYPE_SILU_4,
GGML_METAL_KERNEL_TYPE_ELU,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@ -648,6 +650,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@ -967,6 +970,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_ELU:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -1193,35 +1197,39 @@ static void ggml_metal_encode_node(
const int32_t dim = ((const int32_t *) dst->op_params)[0];
ggml_metal_kargs_concat args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.dim =*/ dim,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
const int nth = MIN(1024, ne0);
@ -1239,8 +1247,6 @@ static void ggml_metal_encode_node(
bool bcast_row = false;
int64_t nb = ne00; // used by the "row" kernels
id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@ -1249,7 +1255,6 @@ static void ggml_metal_encode_node(
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
switch (dst->op) {
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
@ -1269,36 +1274,39 @@ static void ggml_metal_encode_node(
}
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.offs =*/ offs,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
if (bcast_row) {
const int64_t n = ggml_nelements(dst)/4;
@ -1322,25 +1330,29 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error");
}
ggml_metal_kargs_repeat args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
@ -1369,25 +1381,29 @@ static void ggml_metal_encode_node(
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
ggml_metal_kargs_cpy args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@ -1396,35 +1412,39 @@ static void ggml_metal_encode_node(
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
/*.nb03 =*/ pnb3,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
/*.offs =*/ offs,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@ -1572,6 +1592,18 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_ELU:
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
{
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@ -1640,6 +1672,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1715,6 +1748,8 @@ static void ggml_metal_encode_node(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// TODO: add ggml_metal_kargs struct
// TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
@ -1731,6 +1766,7 @@ static void ggml_metal_encode_node(
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@ -1747,6 +1783,7 @@ static void ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
}
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -1771,6 +1808,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -1841,6 +1879,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -1959,24 +1998,29 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("MUL MAT-MAT not implemented");
}
ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
/*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
@ -2154,28 +2198,32 @@ static void ggml_metal_encode_node(
}
};
ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
@ -2288,27 +2336,30 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("MUL_MAT_ID not implemented");
}
ggml_metal_kargs_mul_mm_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
/*.nbi1 =*/ nb21,
/*.ne00 =*/ ne00,
/*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
@ -2467,30 +2518,34 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne00 >= nth0*nth1);
}
ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
/*.nbi1 =*/ nb21,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb1 =*/ nb1,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
const int64_t _ne1 = 1;
const int tgz = dst_rows;
@ -2563,6 +2618,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
}
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@ -2586,20 +2642,28 @@ static void ggml_metal_encode_node(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < 1024) {
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
nth = MIN(nth, ne00/4);
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
@ -2624,6 +2688,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2641,22 +2706,35 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int nth = MIN(256, ne00);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00/4);
ggml_metal_kargs_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
@ -2706,40 +2784,44 @@ static void ggml_metal_encode_node(
};
}
ggml_metal_kargs_rope args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.n_past =*/ n_past,
/*.n_dims =*/ n_dims,
/*.n_ctx_orig =*/ n_ctx_orig,
/*.freq_base =*/ freq_base,
/*.freq_scale =*/ freq_scale,
/*.ext_factor =*/ ext_factor,
/*.attn_factor =*/ attn_factor,
/*.beta_fast =*/ beta_fast,
/*.beta_slow =*/ beta_slow,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
if (id_src2 != nil) {
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
[encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@ -2796,6 +2878,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error");
};
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2836,6 +2919,7 @@ static void ggml_metal_encode_node(
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2870,6 +2954,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2906,6 +2991,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
@ -2927,6 +3013,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2965,6 +3052,7 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("fatal error");
};
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -2983,6 +3071,7 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@ -3224,37 +3313,41 @@ static void ggml_metal_encode_node(
}
}
ggml_metal_kargs_flash_attn_ext args = {
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne11 =*/ ne11,
/*.ne_12_2 =*/ ne12,
/*.ne_12_3 =*/ ne13,
/*.nb_12_1 =*/ nb11,
/*.nb_12_2 =*/ nb12,
/*.nb_12_3 =*/ nb13,
/*.nb31 =*/ nb31,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.n_head_log2 =*/ n_head_log2,
/*.logit_softcap =*/ logit_softcap,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
if (id_src3) {
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
if (!use_vec_kernel) {
// half8x8 kernel
@ -3389,25 +3482,29 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
}
ggml_metal_kargs_cpy args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@ -3452,6 +3549,7 @@ static void ggml_metal_encode_node(
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
// TODO: add ggml_metal_kargs struct
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

File diff suppressed because it is too large Load diff

View file

@ -14,51 +14,51 @@
#include <vector>
struct ggml_opt_dataset {
struct ggml_context * ctx;
ggml_backend_buffer_t buf;
struct ggml_tensor * data;
struct ggml_tensor * labels;
struct ggml_context * ctx = nullptr;
ggml_backend_buffer_t buf = nullptr;
struct ggml_tensor * data = nullptr;
struct ggml_tensor * labels = nullptr;
int64_t ndata;
int64_t ndata_shard;
size_t nbs_data;
size_t nbs_labels;
int64_t ndata = -1;
int64_t ndata_shard = -1;
size_t nbs_data = -1;
size_t nbs_labels = -1;
std::vector<int64_t> permutation;
};
struct ggml_opt_context {
ggml_backend_sched_t backend_sched;
ggml_cgraph * allocated_graph;
ggml_cgraph * allocated_graph_copy;
struct ggml_context * ctx_static;
struct ggml_context * ctx_static_cpu;
struct ggml_context * ctx_compute;
struct ggml_context * ctx_copy;
ggml_backend_buffer_t buf_static;
ggml_backend_buffer_t buf_static_cpu;
ggml_backend_sched_t backend_sched = nullptr;
ggml_cgraph * allocated_graph = nullptr;
ggml_cgraph * allocated_graph_copy = nullptr;
struct ggml_context * ctx_static = nullptr;
struct ggml_context * ctx_static_cpu = nullptr;
struct ggml_context * ctx_compute = nullptr;
struct ggml_context * ctx_copy = nullptr;
ggml_backend_buffer_t buf_static = nullptr;
ggml_backend_buffer_t buf_static_cpu = nullptr;
std::mt19937 rng;
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
struct ggml_tensor * labels;
struct ggml_tensor * inputs = nullptr;
struct ggml_tensor * outputs = nullptr;
struct ggml_tensor * labels = nullptr;
struct ggml_tensor * loss;
struct ggml_tensor * pred;
struct ggml_tensor * ncorrect;
struct ggml_tensor * loss = nullptr;
struct ggml_tensor * pred = nullptr;
struct ggml_tensor * ncorrect = nullptr;
struct ggml_cgraph * gf;
struct ggml_cgraph * gb_grad;
struct ggml_cgraph * gb_opt;
struct ggml_cgraph * gf = nullptr;
struct ggml_cgraph * gb_grad = nullptr;
struct ggml_cgraph * gb_opt = nullptr;
int64_t iter;
int32_t opt_period;
int32_t opt_i;
bool loss_per_datapoint;
int64_t iter = 1;
int32_t opt_period = 1;
int32_t opt_i = 0;
bool loss_per_datapoint = false;
ggml_opt_get_optimizer_params get_opt_pars;
void * get_opt_pars_ud;
struct ggml_tensor * adamw_params;
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
void * get_opt_pars_ud = nullptr;
struct ggml_tensor * adamw_params = nullptr;
};
struct ggml_opt_result {
@ -67,8 +67,8 @@ struct ggml_opt_result {
std::vector<int32_t> pred;
int64_t ncorrect = 0;
bool loss_per_datapoint = false;
int64_t opt_period = -1;
bool loss_per_datapoint = false;
};
// ====== Dataset ======
@ -237,25 +237,33 @@ static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_
return new_tensor;
}
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
std::map<ggml_tensor *, ggml_tensor *> tensor_map;
ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);
for (int i = 0; i < graph->n_leafs; i++) {
ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
for (int i = 0; i < src->n_leafs; i++) {
ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
}
for (int i = 0; i < graph->n_nodes; i++) {
ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
GGML_ASSERT(dst->n_leafs == src->n_leafs);
for (int i = 0; i < src->n_nodes; i++) {
ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
}
for (int i = 0; i < graph->n_nodes; ++i) {
const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
graph->grads[igrad_dst] = new_graph->grads[igrad_src];
graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
GGML_ASSERT(dst->n_nodes == src->n_nodes);
for (int i = 0; i < src->n_nodes; ++i) {
const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
dst->grads[igrad_dst] = src->grads[igrad_src];
dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
}
return new_graph;
return dst;
}
static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
@ -285,15 +293,10 @@ static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_opt_context_t result = new struct ggml_opt_context;
result->backend_sched = params.backend_sched;
result->allocated_graph = nullptr;
result->allocated_graph_copy = nullptr;
result->ctx_compute = params.ctx_compute;
result->ctx_copy = nullptr;
result->inputs = params.inputs;
result->outputs = params.outputs;
result->iter = 1;
result->opt_period = params.opt_period;
result->opt_i = 0;
result->get_opt_pars = params.get_opt_pars;
result->get_opt_pars_ud = params.get_opt_pars_ud;
@ -348,7 +351,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
switch (params.loss_type) {
case GGML_OPT_LOSS_TYPE_MEAN: {
result->labels = nullptr;
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
@ -358,7 +360,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
break;
}
case GGML_OPT_LOSS_TYPE_SUM: {
result->labels = nullptr;
result->loss = ggml_sum(result->ctx_static, result->outputs);
ggml_set_name(result->loss, "loss_sum");
result->loss_per_datapoint = false;
@ -413,14 +414,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
}
if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
result->gb_grad = nullptr;
result->gb_opt = nullptr;
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
result->buf_static_cpu = nullptr;
ggml_opt_alloc_graph(result, result->gf);
return result;
}
@ -429,14 +423,8 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
result->gb_opt = nullptr;
result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
result->buf_static_cpu = nullptr;
ggml_opt_alloc_graph(result, result->gb_grad);
ggml_graph_reset(result->gb_grad);
return result;
}
@ -466,7 +454,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
ggml_opt_alloc_graph(result, result->gb_opt);
ggml_graph_reset(result->gb_opt);
return result;

View file

@ -4350,10 +4350,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
return false;
}
} else {
a = op->src[2];
b = op->src[1];

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -158,6 +158,7 @@ struct vk_device_struct {
std::string name;
uint64_t max_memory_allocation_size;
bool fp16;
bool pipeline_robustness;
vk::Device device;
uint32_t vendor_id;
vk_queue compute_queue;
@ -218,6 +219,7 @@ struct vk_device_struct {
vk_pipeline pipeline_tanh_f32;
vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_argsort_f32;
@ -388,6 +390,7 @@ struct vk_op_soft_max_push_constants {
float m0;
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
};
struct vk_op_argsort_push_constants {
@ -652,7 +655,7 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
@ -722,6 +725,15 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
vk::PipelineCreateFlags(),
pipeline_shader_create_info,
pipeline->layout);
vk::PipelineRobustnessCreateInfoEXT rci;
if (device->pipeline_robustness && disable_robustness) {
rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
compute_pipeline_create_info.setPNext(&rci);
}
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
{
@ -1259,7 +1271,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
std::vector<std::future<void>> compiles;
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align) {
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@ -1269,7 +1281,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
};
if (device->fp16) {
@ -1368,45 +1380,45 @@ static void ggml_vk_load_shaders(vk_device& device) {
// computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@ -1497,8 +1509,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@ -1587,12 +1601,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool fp16_storage = false;
bool fp16_compute = false;
bool pipeline_robustness = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true;
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
pipeline_robustness = true;
}
}
@ -1638,10 +1655,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
vk11_features.pNext = &vk12_features;
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
pl_robustness_features.pNext = nullptr;
pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
pl_robustness_features.pipelineRobustness = VK_FALSE;
if (pipeline_robustness) {
vk12_features.pNext = &pl_robustness_features;
device_extensions.push_back("VK_EXT_pipeline_robustness");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
if (!vk11_features.storageBuffer16BitAccess) {
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
throw std::runtime_error("Unsupported device");
@ -1764,11 +1793,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
fp16 = fp16 && vk12_features.shaderFloat16;
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %d = %s (%s) | uma: %d | fp16: %d | warp size: %d\n",
idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size);
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
}
}
@ -1937,8 +1966,7 @@ void ggml_vk_instance_init() {
vk_instance.device_indices.push_back(0);
}
}
GGML_LOG_DEBUG("ggml_vulkan: Found %d Vulkan devices:\n", vk_instance.device_indices.size());
GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
ggml_vk_print_gpu_info(i);
@ -3187,7 +3215,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ne01 > max_groups_x) {
groups_z = 64;
groups_x /= groups_z;
groups_x = CEIL_DIV(groups_x, groups_z);
}
// compute
@ -3764,7 +3792,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (ne01 > max_groups_x) {
groups_z = 64;
groups_x /= groups_z;
groups_x = CEIL_DIV(groups_x, groups_z);
}
// compute
@ -3933,10 +3961,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32;
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_soft_max_f32_f16;
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
}
return nullptr;
case GGML_OP_ROPE:
@ -4582,6 +4610,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
scale, max_bias,
m0, m1,
n_head_log2,
nrows_x,
}, dryrun);
}

View file

@ -2,6 +2,15 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#include "types.comp"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
}
#endif
#if defined(DATA_A_Q4_1)
@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
}
#endif
#if defined(DATA_A_Q5_0)
@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
}
#endif
#if defined(DATA_A_Q5_1)
@ -50,6 +78,15 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const float m = float(data_a_packed16[a_offset + ib].m);
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
}
#endif
#if defined(DATA_A_Q8_0)
@ -57,6 +94,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a[a_offset + ib].d);
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
}
#endif
#if defined(DATA_A_IQ4_NL)
@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const float d = float(data_a_packed16[a_offset + ib].d);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
}
#endif

View file

@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq4nl_shmem();
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;

View file

@ -12,6 +12,10 @@ void main() {
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
if (i00 >= p.ne00) {
return;
}

View file

@ -3,9 +3,7 @@
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_null_initializer : enable
#extension GL_EXT_shader_explicit_arithmetic_types : require
#include "mul_mat_vec_base.comp"
@ -14,16 +12,48 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif
uint a_offset, b_offset, d_offset, y_offset;
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
const uint col = i*BLOCK_SIZE + 2*tid;
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
#if K_PER_ITER == 8
#if QUANT_R == 2
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
#else
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
@ -36,9 +66,24 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
if (!OOB) {
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
}
#endif
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
#if K_PER_ITER == 8
const vec4 v = dequantize4(ib, iqs, a_offset);
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
// matrix multiplication
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
#else
const vec2 v = dequantize(ib, iqs, a_offset);
// matrix multiplication
@ -46,6 +91,7 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
if (!OOB) {
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
}
#endif
}
}
@ -57,24 +103,39 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
FLOAT_TYPE temp[NUM_ROWS] = {};
FLOAT_TYPE temp[NUM_ROWS];
const int unroll_count = 8;
for (uint i = 0; i < NUM_ROWS; ++i) {
temp[i] = FLOAT_TYPE(0);
}
const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i, false);
i += 2;
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i, true);
i += 2;
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
i++;
}
// sum up partial sums and write back result
@ -100,10 +161,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View file

@ -12,6 +12,9 @@
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

View file

@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
if (row >= p.stride_d) {
return;
}
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

View file

@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
if (row >= p.stride_d) {
return;
}
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);

View file

@ -8,30 +8,14 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
// Declare aliased versions of A and B bindings that can use 16b/32b loads for
// the quantized values, and vec4 loads for B.
struct block_q4_K_u32
{
f16vec2 d;
uint32_t scales[3*QUANT_K/64/4];
uint32_t qs[QUANT_K/2/4];
};
struct block_q4_K_u16
{
f16vec2 d;
uint16_t scales[3*QUANT_K/64/2];
uint16_t qs[QUANT_K/2/2];
};
layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
if (row >= p.stride_d) {
return;
}
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -64,9 +48,9 @@ void main() {
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
@ -80,8 +64,8 @@ void main() {
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;

View file

@ -1,5 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
if (row >= p.stride_d) {
return;
}
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -31,70 +37,106 @@ void main() {
const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
f16vec2 d = data_a[ib0 + i].d;
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
const uint32_t sc0 = ( scale0.x & 0x3f);
const uint32_t sc1 = ( scale0.y & 0x3f);
const uint32_t sc2 = ( scale4.x & 0x3f);
const uint32_t sc3 = ( scale4.y & 0x3f);
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
const uint32_t q4_0 = qs0_16_lo4.x;
const uint32_t q4_1 = qs0_16_lo4.y;
const uint32_t q4_2 = qs0_16_lo4.z;
const uint32_t q4_3 = qs0_16_lo4.w;
const uint32_t q4_4 = qs0_16_hi4.x;
const uint32_t q4_5 = qs0_16_hi4.y;
const uint32_t q4_6 = qs0_16_hi4.z;
const uint32_t q4_7 = qs0_16_hi4.w;
const uint32_t q4_8 = qs64_80_lo4.x;
const uint32_t q4_9 = qs64_80_lo4.y;
const uint32_t q4_10 = qs64_80_lo4.z;
const uint32_t q4_11 = qs64_80_lo4.w;
const uint32_t q4_12 = qs64_80_hi4.x;
const uint32_t q4_13 = qs64_80_hi4.y;
const uint32_t q4_14 = qs64_80_hi4.z;
const uint32_t q4_15 = qs64_80_hi4.w;
B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
uint32_t qh1 = qh0 >> 8;
uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
uint32_t qh17 = qh16 >> 8;
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
const uint tmp_idx = 16 * ix + tid;
tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
}
tmp[gl_LocalInvocationID.x] = temp;
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {

View file

@ -1,5 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
if (row >= p.stride_d) {
return;
}
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
@ -36,35 +42,60 @@ void main() {
const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
#if K_QUANTS_PER_ITERATION == 1
const uint tmp_idx = 16 * ix + tid;
tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
#else
FLOAT_TYPE scales[4];
scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
uvec4 q0 = uvec4(unpack8(q0_u32));
uvec4 q1 = uvec4(unpack8(q1_u32));
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
}
tmp[16 * ix + tid] += sum;
#endif
temp += sum * d;
}
tmp[gl_LocalInvocationID.x] = temp;
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {

View file

@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
#endif
void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else

View file

@ -1,6 +1,7 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
@ -11,14 +12,13 @@ layout (push_constant) uniform parameter
float m0;
float m1;
uint n_head_log2;
uint nrows_x;
} p;
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
@ -26,11 +26,18 @@ layout (binding = 2) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
void main() {
// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate
// over all the columns. The main function tries to pass a constant here,
// as if it were a template function, to allow unrolling.
void soft_max(uint num_iters) {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
if (rowx >= p.nrows_x) {
return;
}
float slope = 1.0f;
// ALiBi
@ -46,19 +53,39 @@ void main() {
// Find max
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
// Cache values while we compute the max, so we don't need to read them
// again when we're ready to compute exp(x-max).
const uint DATA_CACHE_SIZE = 16;
FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
break;
FLOAT_TYPE a = FLOAT_TYPE(0);
if (col < p.KX) {
a = data_a[rowx * p.KX + col];
}
max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
FLOAT_TYPE b = FLOAT_TYPE(0);
if (p.KY > 0 && col < p.KX) {
b = data_b[rowy * p.KX + col];
}
FLOAT_TYPE v = a * p.scale + slope * b;
if (col < p.KX) {
max_val = max(max_val, v);
}
if (idx < DATA_CACHE_SIZE) {
data_cache[idx] = v;
}
}
// reduce across the workgroup
vals[tid] = max_val;
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(vals[tid], vals[tid + s]);
}
@ -68,39 +95,80 @@ void main() {
max_val = vals[0];
barrier();
// Sum up values
vals[tid] = FLOAT_TYPE(0.0f);
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
// Compute sum{exp(x - max)}
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
break;
}
// compute exp(a*scale+b*slope), add it to sum, and cache the new value
// in data_cache if possible.
const uint i = rowx * p.KX + col;
const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
vals[tid] += val;
FLOAT_TYPE val;
if (idx < DATA_CACHE_SIZE) {
val = exp(data_cache[idx] - max_val);
} else {
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
}
sum += val;
if (idx < DATA_CACHE_SIZE) {
data_cache[idx] = val;
} else {
data_d[i] = D_TYPE(val);
}
}
// reduce across the workgroup
vals[tid] = sum;
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}
sum = vals[0];
const D_TYPE divisor = D_TYPE(vals[0]);
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
const uint col = col0 + tid;
if (col >= p.KX) {
break;
continue;
}
data_d[rowx*p.KX + col] /= divisor;
if (idx < DATA_CACHE_SIZE) {
data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
} else {
data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
}
}
}
void main() {
// instantiate the soft_max function for several different
// dimensions, to allow loop unrolling
uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
if (num_blocks > 32) {
soft_max(num_blocks);
} else if (num_blocks > 16) {
soft_max(32);
} else if (num_blocks > 8) {
soft_max(16);
} else if (num_blocks > 4) {
soft_max(8);
} else if (num_blocks == 4) {
soft_max(4);
} else if (num_blocks == 3) {
soft_max(3);
} else if (num_blocks == 2) {
soft_max(2);
} else if (num_blocks == 1) {
soft_max(1);
}
}

View file

@ -1,6 +1,8 @@
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP
#extension GL_EXT_shader_explicit_arithmetic_types : require
#if defined(DATA_A_F32)
#define QUANT_K 1
@ -38,8 +40,14 @@ struct block_q4_0
float16_t d;
uint8_t qs[16];
};
struct block_q4_0_packed16
{
float16_t d;
uint16_t qs[16/2];
};
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#endif
#if defined(DATA_A_Q4_1)
@ -54,7 +62,15 @@ struct block_q4_1
uint8_t qs[16];
};
struct block_q4_1_packed16
{
float16_t d;
float16_t m;
uint16_t qs[16/2];
};
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#endif
#if defined(DATA_A_Q5_0)
@ -70,7 +86,15 @@ struct block_q5_0
uint8_t qs[16];
};
struct block_q5_0_packed16
{
float16_t d;
uint16_t qh[2];
uint16_t qs[16/2];
};
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#endif
#if defined(DATA_A_Q5_1)
@ -87,7 +111,16 @@ struct block_q5_1
uint8_t qs[16];
};
struct block_q5_1_packed16
{
float16_t d;
float16_t m;
uint qh;
uint16_t qs[16/2];
};
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#endif
#if defined(DATA_A_Q8_0)
@ -100,8 +133,14 @@ struct block_q8_0
float16_t d;
int8_t qs[32];
};
struct block_q8_0_packed16
{
float16_t d;
uint16_t qs[32/2];
};
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#endif
// K-quants
@ -116,7 +155,23 @@ struct block_q2_K
f16vec2 d;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K/16/2];
uint16_t qs[QUANT_K/4/2];
f16vec2 d;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K/16/4];
uint32_t qs[QUANT_K/4/4];
f16vec2 d;
};
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#endif
#if defined(DATA_A_Q3_K)
@ -131,7 +186,16 @@ struct block_q3_K
float16_t d;
};
struct block_q3_K_packed16
{
uint16_t hmask[QUANT_K/8/2];
uint16_t qs[QUANT_K/4/2];
uint16_t scales[12/2];
float16_t d;
};
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#endif
#if defined(DATA_A_Q4_K)
@ -145,7 +209,23 @@ struct block_q4_K
uint8_t qs[QUANT_K/2];
};
struct block_q4_K_packed16
{
f16vec2 d;
uint16_t scales[3*QUANT_K/64/2];
uint16_t qs[QUANT_K/2/2];
};
struct block_q4_K_packed32
{
f16vec2 d;
uint32_t scales[3*QUANT_K/64/4];
uint32_t qs[QUANT_K/2/4];
};
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#endif
#if defined(DATA_A_Q5_K)
@ -160,7 +240,16 @@ struct block_q5_K
uint8_t qs[QUANT_K/2];
};
struct block_q5_K_packed16
{
f16vec2 d;
uint16_t scales[12/2];
uint16_t qh[QUANT_K/8/2];
uint16_t qs[QUANT_K/2/2];
};
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#endif
#if defined(DATA_A_Q6_K)
@ -175,7 +264,16 @@ struct block_q6_K
float16_t d;
};
struct block_q6_K_packed16
{
uint16_t ql[QUANT_K/2/2];
uint16_t qh[QUANT_K/4/2];
int8_t scales[QUANT_K/16];
float16_t d;
};
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#endif
// IQuants
@ -191,10 +289,30 @@ struct block_iq4_nl
uint8_t qs[QUANT_K/2];
};
#define A_TYPE block_iq4_nl
struct block_iq4_nl_packed16
{
float16_t d;
uint16_t qs[QUANT_K/2/2];
};
const int8_t kvalues_iq4nl[16] = {
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
};
shared FLOAT_TYPE kvalues_iq4nl[16];
void init_iq4nl_shmem()
{
// copy the table into shared memory and sync
if (gl_LocalInvocationIndex.x < 16) {
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
}
barrier();
}
#endif
#endif // !defined(GGML_TYPES_COMP)

View file

@ -318,10 +318,10 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
// Dequant shaders
if (tname != "f16") {
@ -332,11 +332,11 @@ void process_shaders() {
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
}
string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
}
}

View file

@ -5020,8 +5020,10 @@ static void ggml_hash_map_free(struct hash_map * map) {
}
// utility functions to change gradients
// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
// else if a is in zero_table, replace a
// isrc is the index of tensor in cgraph->visited_has_set.keys
// the corresponding gradient (accumulators) are also at position isrc
// if tensor has a gradient accumulator, modify that accumulator in-place
// else if there is no gradient for tensor, set the corresponding value
// else, just add/subtract/etc. the gradients
static void ggml_add_or_set(
@ -5029,11 +5031,14 @@ static void ggml_add_or_set(
struct ggml_cgraph * cgraph,
size_t isrc,
struct ggml_tensor * tensor) {
struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
GGML_ASSERT(src);
if (cgraph->grads[isrc]) {
cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
} else {
cgraph->grads[isrc] = tensor;
}
ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
@ -5041,18 +5046,20 @@ static void ggml_acc_or_set(
struct ggml_context * ctx,
struct ggml_cgraph * cgraph,
size_t isrc,
struct ggml_tensor * src,
struct ggml_tensor * tensor,
const size_t nb1,
const size_t nb2,
const size_t nb3,
const size_t offset) {
struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
GGML_ASSERT(src);
if (cgraph->grads[isrc]) {
cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
} else {
struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
}
ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
@ -5060,13 +5067,15 @@ static void ggml_add1_or_set(
struct ggml_context * ctx,
struct ggml_cgraph * cgraph,
size_t isrc,
struct ggml_tensor * src,
struct ggml_tensor * tensor) {
struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
GGML_ASSERT(src);
if (cgraph->grads[isrc]) {
cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
} else {
cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
}
ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
@ -5075,11 +5084,14 @@ static void ggml_sub_or_set(
struct ggml_cgraph * cgraph,
size_t isrc,
struct ggml_tensor * tensor) {
struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
GGML_ASSERT(src);
if (cgraph->grads[isrc]) {
cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
} else {
cgraph->grads[isrc] = ggml_neg(ctx, tensor);
}
ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
@ -5096,12 +5108,12 @@ static void ggml_compute_backward(
struct ggml_tensor * src1 = tensor->src[1];
struct ggml_tensor * src2 = tensor->src[2];
struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
const size_t isrc0 = ggml_hash_find(hash_set, src0);
const size_t isrc1 = ggml_hash_find(hash_set, src1);
const size_t isrc2 = ggml_hash_find(hash_set, src2);
const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
switch (tensor->op) {
case GGML_OP_DUP: {
@ -5201,7 +5213,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_SUM: {
if (src0_needs_grads) {
ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
ggml_add1_or_set(ctx, cgraph, isrc0, grad);
}
} break;
case GGML_OP_SUM_ROWS: {
@ -5211,7 +5223,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_MEAN: {
if (src0_needs_grads) {
ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
}
} break;
case GGML_OP_REPEAT: {
@ -5364,7 +5376,7 @@ static void ggml_compute_backward(
nb3 = (nb3 / n0) * ng;
}
ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
}
} break;
case GGML_OP_PERMUTE: {
@ -5598,10 +5610,9 @@ void ggml_build_backward_expand(
const int n_nodes_f = cgraph->n_nodes;
const size_t hash_size = ggml_hash_size(2*cgraph->size);
memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
bool * grads_needed = calloc(hash_size, sizeof(bool));
memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
{
bool any_params = false;
@ -5622,7 +5633,7 @@ void ggml_build_backward_expand(
continue;
}
bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
bool ignore_src[GGML_MAX_SRC] = {false};
switch (node->op) {
// gradients in node->src[0] for one reason or another have no effect on output gradients
@ -5666,9 +5677,12 @@ void ggml_build_backward_expand(
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
GGML_ASSERT(igrad != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node);
cgraph->grad_accs[igrad] = cgraph->grads[igrad];
cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
cgraph->grads[igrad] = cgraph->grad_accs[igrad];
ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
}
grads_needed[igrad] = true;
}
@ -5766,10 +5780,10 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
/*.n_nodes =*/ i1 - i0,
/*.n_leafs =*/ 0,
/*.nodes =*/ cgraph0->nodes + i0,
/*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
/*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
/*.grads =*/ NULL, // gradients would need visited_hash_set
/*.grad_accs =*/ NULL,
/*.leafs =*/ NULL,
/*.hash_table =*/ { 0, NULL, NULL },
/*.visited_hash_set =*/ { 0, NULL, NULL },
/*.order =*/ cgraph0->order,
};
@ -5800,12 +5814,22 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
}
}
if (dst->grads) {
memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
}
if (src->grads) {
GGML_ASSERT(dst->grads != NULL);
GGML_ASSERT(dst->grad_accs != NULL);
for (int i = 0; i < src->n_nodes; ++i) {
const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
dst->grads[igrad_dst] = src->grads[igrad_src];
dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
}
@ -5840,13 +5864,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
if (node->op == GGML_OP_OPT_STEP_ADAMW) {
// clear momenta
if (node->src[2]->data) {
ggml_set_zero(node->src[2]);
}
if (node->src[3]->data) {
ggml_set_zero(node->src[3]);
}
}
// initial gradients of loss should be 1, 0 otherwise
if (grad_acc) {

View file

@ -243,6 +243,7 @@ class MODEL_ARCH(IntEnum):
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
OLMO_1124 = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
@ -404,6 +405,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMO_1124: "olmo_1124",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
@ -1069,6 +1071,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.OLMO_1124: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.OLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View file

@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron olmoe
"model.embed_tokens", # llama-hf nemotron olmoe olmo_1124
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@ -54,7 +54,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo_1124
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@ -66,7 +66,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2 olmoe
"model.norm", # llama-hf baichuan internlm2 olmoe olmo_1124
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
@ -145,7 +145,7 @@ class TensorNameMap:
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
@ -157,7 +157,7 @@ class TensorNameMap:
# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
@ -170,7 +170,7 @@ class TensorNameMap:
# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
@ -188,7 +188,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo_1124
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
@ -215,7 +215,7 @@ class TensorNameMap:
),
MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm", # gemma2
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo_1124
),
# Rotary embeddings
@ -250,7 +250,7 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo_1124
),
MODEL_TENSOR.FFN_GATE_INP: (
@ -273,7 +273,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo_1124
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
@ -314,7 +314,7 @@ class TensorNameMap:
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo_1124
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
@ -346,7 +346,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo_1124
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j
@ -383,7 +383,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo_1124
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
@ -392,7 +392,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo_1124
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm

View file

@ -669,6 +669,9 @@ extern "C" {
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
//
// State / sessions
//

View file

@ -193,6 +193,7 @@ enum llm_arch {
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
LLM_ARCH_OLMO_1124,
LLM_ARCH_OLMOE,
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
@ -246,6 +247,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_OLMO_1124, "olmo_1124" },
{ LLM_ARCH_OLMOE, "olmoe" },
{ LLM_ARCH_OPENELM, "openelm" },
{ LLM_ARCH_ARCTIC, "arctic" },
@ -1221,6 +1223,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_OLMO_1124,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_OLMOE,
{
@ -3481,21 +3502,13 @@ static bool llama_kv_cache_init(
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const llama_model::buft_list_t * buft_list;
ggml_backend_buffer_type_t buft;
if (offload) {
buft_list = model.dev_layer.at(i).buft_list;
auto * dev = model.dev_layer.at(i).dev;
buft = ggml_backend_dev_buffer_type(dev);
} else {
buft_list = &model.cpu_buft_list;
buft = ggml_backend_cpu_buffer_type();
}
ggml_backend_buffer_type_t buft = select_buft(*buft_list,
[&](ggml_context * ctx) {
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
return k;
}
ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
});
ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) {
@ -5917,6 +5930,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OLMO_1124:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 16: model.type = e_model::MODEL_1B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OLMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -8698,6 +8722,31 @@ static bool llm_load_tensors(
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_OLMO_1124:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_OLMOE:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -14595,6 +14644,130 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_olmo_1124() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
cur = inpL;
// self_attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur_rope", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_ffn(ctx0, lctx, ffn_inp,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post_norm", -1);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
// based on the build_qwen2moe() function, changes:
// * removed shared experts
// * removed bias
@ -16787,6 +16960,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_olmo();
} break;
case LLM_ARCH_OLMO_1124:
{
result = llm.build_olmo_1124();
} break;
case LLM_ARCH_OLMOE:
{
result = llm.build_olmoe();
@ -18199,7 +18376,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
if (!llama_kv_cache_can_shift(&lctx)) {
GGML_ABORT("Deepseek2 does not support K-shift");
}
@ -20060,6 +20237,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_OLMO_1124:
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
@ -20451,6 +20629,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
llama_kv_cache_update_internal(*ctx);
}
bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
}
// deprecated
size_t llama_get_state_size(struct llama_context * ctx) {
return llama_state_get_size(ctx);