Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/ISSUE_TEMPLATE/010-bug-compilation.yml
#	.github/ISSUE_TEMPLATE/011-bug-results.yml
#	.github/labeler.yml
#	.github/workflows/build.yml
#	.github/workflows/release.yml
#	.gitmodules
#	CMakeLists.txt
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
#	ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
#	ggml/src/ggml-opencl/kernels/softmax_f16.cl
#	ggml/src/ggml-opencl/kernels/softmax_f32.cl
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/element_wise.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	scripts/sync-ggml-am.sh
#	scripts/sync-ggml.last
#	scripts/sync-ggml.sh
#	tests/test-backend-ops.cpp
#	tests/test-c.c
Concedo 2025-07-05 12:16:28 +08:00
commit 57ce374240
64 changed files with 2944 additions and 979 deletions


@@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model):
]
def set_vocab(self):
with open(self.dir_model / "chat_template.jinja") as f:
# quick hack to make sure chat template is added
self.gguf_writer.add_chat_template(f.read())
super().set_vocab()
def set_gguf_parameters(self):
@@ -4781,6 +4778,14 @@ class ARwkv7Model(Rwkv7Model):
class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA
def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
@@ -4855,6 +4860,100 @@ class MambaModel(TextModel):
return [(new_name, data_torch)]
@ModelBase.register("Mamba2ForCausalLM")
class Mamba2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA2
def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
# It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size
if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
elif (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.backbone") or name.startswith("model.lm_head"):
# map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
name = name.removeprefix("model.")
if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
new_name = self.map_tensor_name(name)
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
data_torch = data_torch.squeeze()
elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]):
# unsqueeze A to use similar shape semantics as Mamba-1
# (D is also unsqueezed, but for more straightforward broadcast internally)
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield (new_name, data_torch)
@ModelBase.register("CohereForCausalLM") @ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel): class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R model_arch = gguf.MODEL_ARCH.COMMAND_R
@ -6615,12 +6714,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
# maybe we should fallback to text model's arch in that case, since not many models have both # maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {}) text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {}) vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0] arch = None
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
arch = arches[0]
elif "ssm_cfg" in hparams:
# For non-hf Mamba and Mamba2 models
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
# if "architectures" is found in the sub-config, use that instead # if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None: if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0] arch = text_config["architectures"][0]
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0] arch = vision_config["architectures"][0]
if arch is None:
raise ValueError("Failed to detect model architecture")
return arch return arch
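The vocab padding added above rounds vocab_size up to the next multiple of pad_vocab using ceiling division (the Python idiom -(vocab_size // -pad_vocab) * pad_vocab). Below is a minimal standalone sketch of the same arithmetic, written in C for illustration; the function name is hypothetical and not part of the converter.

    #include <assert.h>
    #include <stddef.h>

    // Round n up to the next multiple of `multiple` using ceiling division,
    // mirroring the Python expression -(vocab_size // -pad_vocab) * pad_vocab.
    static size_t round_up_multiple(size_t n, size_t multiple) {
        return ((n + multiple - 1) / multiple) * multiple;
    }

    int main(void) {
        assert(round_up_multiple(32000, 16) == 32000); // already aligned
        assert(round_up_multiple(32001, 16) == 32016); // padded up
        return 0;
    }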


@@ -1,50 +0,0 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_KOMPUTE_MAX_DEVICES 16
struct ggml_vk_device {
int index;
int type; // same as VkPhysicalDeviceType
size_t heapSize;
const char * name;
const char * vendor;
int subgroupSize;
uint64_t bufferAlignment;
uint64_t maxAlloc;
};
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
bool ggml_vk_has_vulkan(void);
bool ggml_vk_has_device(void);
struct ggml_vk_device ggml_vk_current_device(void);
//
// backend API
//
// forward declaration
typedef struct ggml_backend * ggml_backend_t;
GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
#ifdef __cplusplus
}
#endif


@@ -563,6 +563,8 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
GGML_GLU_OP_COUNT,
};
@@ -659,6 +661,9 @@ extern "C" {
// misc
GGML_API const char * ggml_version(void);
GGML_API const char * ggml_commit(void);
GGML_API void ggml_time_init(void); // call this once at the beginning of the program
GGML_API int64_t ggml_time_ms(void);
GGML_API int64_t ggml_time_us(void);
@@ -1157,6 +1162,22 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: n columns, r rows,
// B: n columns, r rows,
GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1180,6 +1201,16 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1523,8 +1554,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// a [ne0, ne01, ne02, ne03]
// mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
//
// broadcast:
// ne02 % ne12 == 0
// ne03 % ne13 == 0
//
// fused soft_max(a*scale + mask*(ALiBi slope))
// mask is optional
// max_bias = 0.0f for no ALiBi
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
@@ -1987,11 +2024,17 @@ extern "C" {
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3 ]
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]
// v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
// mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
//
// broadcast:
// n_head % n_head_kv == 0
// n_head % ne32 == 0
// ne3 % ne33 == 0
//
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
@@ -2030,7 +2073,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C,
struct ggml_tensor * ids);
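The broadcast rules noted above for ggml_soft_max_ext and ggml_flash_attn_ext (ne02 % ne12 == 0, ne03 % ne13 == 0, and the analogous ne32/ne33 conditions) boil down to indexing the smaller mask with a modulo so it repeats over the larger tensor. A rough standalone sketch of that index mapping follows; the names are illustrative only and not ggml API.

    #include <stdio.h>

    // If dim 2/3 of the mask divide dim 2/3 of the activations, the mask plane
    // used for (i02, i03) is simply (i02 % ne12, i03 % ne13), i.e. it repeats.
    static void map_broadcast(int ne02, int ne03, int ne12, int ne13) {
        for (int i03 = 0; i03 < ne03; ++i03) {
            for (int i02 = 0; i02 < ne02; ++i02) {
                const int i12 = i02 % ne12;
                const int i13 = i03 % ne13;
                printf("a[%d,%d] uses mask[%d,%d]\n", i02, i03, i12, i13);
            }
        }
    }

    int main(void) {
        map_broadcast(/*ne02=*/4, /*ne03=*/2, /*ne12=*/2, /*ne13=*/1);
        return 0;
    }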
// partition into non-overlapping windows with padding if needed
// example:


@@ -61,10 +61,6 @@
#include "ggml-cann.h"
#endif
#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
@@ -189,9 +185,6 @@ struct ggml_backend_registry {
#ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg());
#endif
#ifdef GGML_USE_KOMPUTE
register_backend(ggml_backend_kompute_reg());
#endif
#ifdef GGML_USE_CPU
register_backend(ggml_backend_cpu_reg());
#endif
@@ -576,7 +569,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("cann", silent, dir_path);
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);
ggml_backend_load_best("sycl", silent, dir_path);


@@ -2186,6 +2186,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
n_tasks = n_threads;
} break;
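For these GLU ops n_tasks equals n_threads, and the CPU kernels further down then split rows across threads with the usual ceiling-divide pattern. A small sketch of that partitioning, assuming nr rows and nth threads; this mirrors the dr/ir0/ir1 pattern in the kernels but is not ggml code itself.

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Thread ith of nth gets a contiguous block of roughly nr/nth rows.
    static void row_range(int nr, int nth, int ith, int * ir0, int * ir1) {
        const int dr = (nr + nth - 1)/nth; // rows per thread, rounded up
        *ir0 = dr*ith;
        *ir1 = MIN(*ir0 + dr, nr);
    }

    int main(void) {
        int ir0, ir1;
        for (int ith = 0; ith < 4; ++ith) {
            row_range(10, 4, ith, &ir0, &ir1);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }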


@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_erf_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_erf_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_geglu_quick
static void ggml_compute_forward_geglu_quick_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_quick_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_quick_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
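In the GEGLU variants added above, when src1 is absent the two GLU operands live in the two halves of each src0 row, and op param 1 ("swapped") selects which half is gated. A minimal scalar sketch of that split for the erf-based variant; it is illustrative only and not the ggml kernel.

    #include <math.h>
    #include <stdio.h>

    // geglu_erf on one row of 2*nc floats: the x half goes through the erf-form
    // GELU, the g half gates it. If swapped, the roles of the halves exchange.
    static void geglu_erf_row(int nc, float * dst, const float * row, int swapped) {
        const float * x = row + (swapped ? nc : 0);
        const float * g = row + (swapped ? 0 : nc);
        for (int i = 0; i < nc; ++i) {
            dst[i] = 0.5f*x[i]*(1.0f + erff(x[i]/sqrtf(2.0f))) * g[i];
        }
    }

    int main(void) {
        const float row[4] = {1.0f, -2.0f, 0.5f, 3.0f}; // two halves of size 2
        float out[2];
        geglu_erf_row(2, out, row, /*swapped=*/0);
        printf("%f %f\n", out[0], out[1]);
        return 0;
    }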
// ggml_compute_forward_norm
static void ggml_compute_forward_norm_f32(
@@ -5232,14 +5518,17 @@ static void ggml_compute_forward_soft_max_f32(
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,69 +5538,67 @@ static void ggml_compute_forward_soft_max_f32(
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const int64_t i11 = i01;
const int64_t i12 = i02%ne12;
const int64_t i13 = i03%ne13;
// ALiBi
const uint32_t h = i02; // head
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// broadcast the mask across rows
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
ggml_vec_cpy_f32 (ne00, wp, sp);
ggml_vec_scale_f32(ne00, wp, scale);
if (mp_f32) {
if (use_f16) {
for (int i = 0; i < ne00; ++i) {
wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
}
} else {
for (int i = 0; i < ne00; ++i) {
wp[i] += slope*mp_f32[i];
}
}
}
#ifndef NDEBUG
for (int i = 0; i < ne00; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
#ifndef NDEBUG
for (int i = 0; i < ne00; ++i) {
assert(!isnan(dp[i]));
assert(!isinf(dp[i]));
}
#endif
}
}
}
}
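The ALiBi slope used above follows the usual two-geometric-series scheme: heads below n_head_log2 use powers of m0 = 2^(-max_bias/n_head_log2), the rest use odd powers of m1 = 2^(-max_bias/2/n_head_log2). A small standalone sketch of that slope computation; it mirrors the expression visible in the kernel, with n_head_log2 taken as the largest power of two not exceeding n_head (an assumption here, following common ALiBi implementations).

    #include <math.h>
    #include <stdio.h>

    static float alibi_slope(float max_bias, unsigned n_head_log2, float m0, float m1, unsigned h) {
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }

    int main(void) {
        const float    max_bias    = 8.0f;
        const unsigned n_head      = 16;
        const unsigned n_head_log2 = 1u << (unsigned) floorf(log2f((float) n_head));
        const float    m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        for (unsigned h = 0; h < n_head; ++h) {
            printf("head %2u: slope %.6f\n", h, alibi_slope(max_bias, n_head_log2, m0, m1, h));
        }
        return 0;
    }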
void ggml_compute_forward_soft_max(
@@ -7798,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
memset(VKQ32, 0, DV*sizeof(float));
}
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
@@ -8336,120 +8623,210 @@ void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // dim
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src1->ne[3]; // number of sequences in the batch
// can't use ggml_nbytes because src1 is not necessarily contiguous
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);
// heads per thread
const int dh = (nh + nth - 1)/nth;
// head range for this thread
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);
const int32_t * ids = (const int32_t *) src6->data;
for (int i3 = 0; i3 < ns; ++i3) {
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
for (int i2 = 0; i2 < nt; ++i2) {
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
if (src3->ne[0] == 1) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int ggml_f32_epr = svcntw();
const int ggml_f32_step = 1 * ggml_f32_epr;
const int np = (nc & ~(ggml_f32_step - 1));
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
for (int i = 0; i < np; i += ggml_f32_step) {
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
t0 = GGML_F32_VEC_ADD(t0, t1);
sum = GGML_F32_VEC_FMA(sum, t0, t2);
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
}
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
GGML_F32_VEC az[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#endif
#else
const int np = 0;
#endif
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
#if defined(__ARM_FEATURE_SVE)
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
// d_state
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
t1 = exp_ps_sve(svptrue_b32(), t1);
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
}
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
#else
float sumf = 0.0f;
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
// and also because expf is used within the loop.
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
#endif
}
}
}
// use the output as the source when it's not the first token-wise iteration
s0 = s;
}
}
}
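Per state element the scan above is just the first-order recurrence state = state*dA + B*(x*dt), with the token output being the dot product of the updated state row with C. A scalar reference sketch of one channel, ignoring heads, grouping, threading and SIMD; it is illustrative only, not the ggml kernel.

    #include <math.h>
    #include <stdio.h>

    // One channel of a selective-scan step over n_tok tokens with d_state states.
    // s is updated in place; y receives one output per token.
    static void ssm_scan_channel(int d_state, int n_tok,
                                 float * s,            // [d_state] recurrent state
                                 const float * x,      // [n_tok]   input for this channel
                                 const float * dt,     // [n_tok]   time step (pre-softplus)
                                 const float * A,      // [d_state] decay coefficients
                                 const float * B,      // [n_tok * d_state]
                                 const float * C,      // [n_tok * d_state]
                                 float * y) {          // [n_tok]
        for (int t = 0; t < n_tok; ++t) {
            const float dt_sp = dt[t] <= 20.0f ? log1pf(expf(dt[t])) : dt[t]; // softplus
            const float x_dt  = x[t]*dt_sp;
            float sum = 0.0f;
            for (int i = 0; i < d_state; ++i) {
                const float dA = expf(dt_sp*A[i]);
                s[i] = s[i]*dA + B[t*d_state + i]*x_dt;  // state = prev_state*dA + dB*x
                sum += s[i]*C[t*d_state + i];            // y = rowwise_dotprod(state, C)
            }
            y[t] = sum;
        }
    }

    int main(void) {
        float s[2] = {0.0f, 0.0f};
        const float x[2]  = {1.0f, 0.5f};
        const float dt[2] = {0.1f, 0.2f};
        const float A[2]  = {-1.0f, -0.5f};
        const float B[4]  = {1.0f, 0.0f, 0.0f, 1.0f};
        const float C[4]  = {1.0f, 1.0f, 1.0f, 1.0f};
        float y[2];
        ssm_scan_channel(2, 2, s, x, dt, A, B, C, y);
        printf("y = %f %f\n", y[0], y[1]);
        return 0;
    }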
void ggml_compute_forward_ssm_scan(
@@ -8688,6 +9065,14 @@ void ggml_compute_forward_glu(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
} break;
case GGML_GLU_OP_GEGLU_QUICK:
{
ggml_compute_forward_geglu_quick(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");


@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
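With the SVE operands reordered as above, GGML_F32xt_FMA(a, b, c) evaluates to a + b*c, assuming svmad_f32_m(pg, op1, op2, op3) computes op1*op2 + op3 (an assumption stated here, not shown in the hunk). That is the contract the accumulation loops in the following vector routines rely on. A scalar model of that contract for illustration:

    #include <assert.h>

    // Scalar model of the reordered macro: FMA(a, b, c) == a + b*c.
    static float f32xt_fma(float a, float b, float c) {
        return a + b*c;
    }

    int main(void) {
        // dot-product style accumulation: sum = FMA(sum, x, y)
        float sum = 0.0f;
        const float x[3] = {1.0f, 2.0f, 3.0f};
        const float y[3] = {4.0f, 5.0f, 6.0f};
        for (int i = 0; i < 3; ++i) {
            sum = f32xt_fma(sum, x[i], y[i]);
        }
        assert(sum == 32.0f); // 1*4 + 2*5 + 3*6
        return 0;
    }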


@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
for (int i = 0; i < np; i += ggml_f32_step) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
}
// leftovers
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
if (np2 < n) {


@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
}
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
}
} }
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
}
}
inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
float xi = x[i];
y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
}
}
inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
}
}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
}
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
}
}
#endif
inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
}
}
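The quick variant above gates with ggml_gelu_quick_f32 (or the fp16 lookup table). A common closed form for that approximation is x*sigmoid(1.702*x); the exact constant used by ggml is not shown in this hunk, so treat it as an assumption. A scalar sketch of the gated product for one row:

    #include <math.h>
    #include <stdio.h>

    // Assumed quick-GELU approximation: x * sigmoid(1.702 * x).
    static float gelu_quick(float x) {
        return x / (1.0f + expf(-1.702f*x));
    }

    // y[i] = gelu_quick(x[i]) * g[i], the scalar shape of ggml_vec_geglu_quick_f32.
    static void geglu_quick_row(int n, float * y, const float * x, const float * g) {
        for (int i = 0; i < n; ++i) {
            y[i] = gelu_quick(x[i]) * g[i];
        }
    }

    int main(void) {
        const float x[3] = {-1.0f, 0.0f, 1.0f};
        const float g[3] = { 2.0f, 2.0f, 2.0f};
        float y[3];
        geglu_quick_row(3, y, x, g);
        printf("%f %f %f\n", y[0], y[1], y[2]);
        return 0;
    }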
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_float sum = 0.0;


@@ -179,6 +179,20 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
const int id = ggml_cuda_get_device(); \
if (!shared_memory_limit_raised[id]) { \
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
shared_memory_limit_raised[id] = true; \
} \
} while (0)
#else
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
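The macro above raises the opt-in dynamic shared-memory limit for a kernel at most once per device, caching that fact in a function-local static array; the cross-entropy call sites further down then shrink to a single line. A plain-C model of that once-per-device guard, with the CUDA attribute call replaced by a stub for illustration:

    #include <stdio.h>
    #include <stdbool.h>

    #define MAX_DEVICES 16

    // Stub standing in for cudaFuncSetAttribute(..., MaxDynamicSharedMemorySize, nbytes).
    static void set_smem_limit(int device, int nbytes) {
        printf("raising dynamic smem limit on device %d to %d bytes\n", device, nbytes);
    }

    // Mirrors the guard in CUDA_SET_SHARED_MEMORY_LIMIT: perform the attribute
    // call at most once per device, no matter how often the op runs.
    static void set_smem_limit_once(int device, int nbytes) {
        static bool raised[MAX_DEVICES] = {false};
        if (!raised[device]) {
            set_smem_limit(device, nbytes);
            raised[device] = true;
        }
    }

    int main(void) {
        set_smem_limit_once(0, 65536);
        set_smem_limit_once(0, 65536); // second call is a no-op
        set_smem_limit_once(1, 65536);
        return 0;
    }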
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else


@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
if (nbytes_shared <= smpbo) {
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);


@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
nb11, nb12, nb13,
nb21, nb22, nb23,


@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
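The recurring change across these flash-attention hunks is the mask indexing: the new ne32/nb32 arguments let a mask with a non-trivial third dimension be broadcast over the head index via a modulo, instead of assuming a single 2D mask. A minimal sketch of that indexing, with a helper name of my own (not part of the ggml API):

```cuda
// Hypothetical helper showing the broadcast indexing that ne32/nb32 enable.
// ne32 == 1 reuses one mask slice for every head; otherwise head % ne32 picks
// the slice. nb31/nb32 are byte strides, hence the arithmetic on char*.
#include <cstdint>

static __host__ __device__ const void * mask_row_ptr(
        const char * mask, int64_t head, int64_t row,
        int64_t ne32, int64_t nb31, int64_t nb32) {
    return mask + nb32*(head % ne32) + nb31*row;
}
```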

View file

@ -6,7 +6,7 @@
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1) __launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_tile_ext_f16( static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ Q, const char * __restrict__ Q,
@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne12, const int ne12,
const int ne13, const int ne13,
const int ne31, const int ne31,
const int ne32,
const int nb31, const int nb31,
const int nb32,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03, const int nb03,
@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0; const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View file

@ -6,7 +6,7 @@
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1) __launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_tile_ext_f32( static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ Q, const char * __restrict__ Q,
@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne12, const int ne12,
const int ne13, const int ne13,
const int ne31, const int ne31,
const int ne32,
const int nb31, const int nb31,
const int nb32,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03, const int nb03,
@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0; const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);

View file

@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne12, const int ne12,
const int ne13, const int ne13,
const int ne31, const int ne31,
const int ne32,
const int nb31, const int nb31,
const int nb32,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03, const int nb03,
@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
K += nb12*(blockIdx.z / gqa_ratio); K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0; const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef); const half slopeh = __float2half(slopef);
@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View file

@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne12, const int ne12,
const int ne13, const int ne13,
const int ne31, const int ne31,
const int ne32,
const int nb31, const int nb31,
const int nb32,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03, const int nb03,
@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
Q += nb02* blockIdx.z + nb01*ic0; Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio); K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

View file

@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
const int ne12, const int ne12,
const int ne13, const int ne13,
const int ne31, const int ne31,
const int ne32,
const int nb31, const int nb31,
const int nb32,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03, const int nb03,
@ -97,8 +99,8 @@ static __global__ void flash_attn_ext_f16(
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); const half2 * mask2 = (const half2 *) maskh;
const int stride_Q = nb01 / sizeof(float); const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

View file

@ -2319,6 +2319,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst); ggml_cuda_op_swiglu(ctx, dst);
break; break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_cuda_op_geglu_quick(ctx, dst);
break;
default: default:
return false; return false;
} }
@ -3121,6 +3127,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]); return ggml_is_contiguous_1(op->src[0]);
default: default:
return false; return false;
@ -3326,12 +3334,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_LOG: case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true; return true;
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2
// (kernel only supports d_state == 128 && d_head % 16 == 0)
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
} else {
// Mamba
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
case GGML_OP_SSM_CONV: {
// assumes d_inner % threads == 0
return op->src[0]->ne[1] % 128 == 0;
}
case GGML_OP_CONT: case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16; return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
return true; return true;
case GGML_OP_SOFT_MAX_BACK: { case GGML_OP_SOFT_MAX_BACK: {
@ -3380,6 +3402,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (op->src[0]->ne[0] == 192) { if (op->src[0]->ne[0] == 192) {
return false; return false;
} }
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) { if (op->src[0]->ne[3] != 1) {
return false; return false;
} }
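The new SSM_SCAN gate a few hunks above is easier to read with named dimensions. The axis-to-hyperparameter mapping below follows the in-line comments; the helper itself is only an illustration, not ggml code:

```cuda
// Illustrative restatement of the SSM_SCAN support check (naming is mine):
//   src0 (state): ne[0] = d_state, ne[1] = head_dim, ne[2] = n_head
//   src3 (A)    : ne[0] == 1 identifies the Mamba-2 layout
//   src4 (B)    : ne[1] = n_group
#include <cstdint>

static bool ssm_scan_shape_supported(int64_t d_state, int64_t head_dim,
                                     int64_t n_head, int64_t n_group, bool is_mamba2) {
    if (is_mamba2) {
        // grouped kernel: fixed 128-wide state, head_dim processed in chunks of 16
        return d_state == 128 && head_dim % 16 == 0;
    }
    // Mamba-1 kernel: 16-wide state, scalar heads, 128 heads per block, one group
    return d_state == 16 && head_dim == 1 && n_head % 128 == 0 && n_group == 1;
}
```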

View file

@ -3017,14 +3017,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc); const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
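launch_mul_mat_q replaces its per-device bookkeeping with a CUDA_SET_SHARED_MEMORY_LIMIT macro. The macro's definition is not shown in this diff; judging from the code it replaces, it presumably wraps a once-per-device guard around cudaFuncSetAttribute along these lines (an assumption, sketched only to convey the intent):

```cuda
// Hypothetical equivalent of the removed bookkeeping: raise the dynamic shared
// memory limit for a given kernel at most once per device. Error checking omitted.
#include <cuda_runtime.h>

#define MY_MAX_DEVICES 16

template <typename K>
static void set_smem_limit_once(K kernel, int device, size_t nbytes) {
    static bool raised[MY_MAX_DEVICES] = {false}; // one flag set per kernel instantiation
    if (!raised[device]) {
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes);
        raised[device] = true;
    }
}
```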

View file

@ -2,6 +2,7 @@
#include "ggml.h" #include "ggml.h"
#include "softmax.cuh" #include "softmax.cuh"
#include <cstdint> #include <cstdint>
#include <utility>
template <typename T> template <typename T>
static __device__ __forceinline__ float t2f32(T val) { static __device__ __forceinline__ float t2f32(T val) {
@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val); return __half2float(val);
} }
struct soft_max_params {
int64_t nheads;
uint32_t n_head_log2;
int64_t ncols;
int64_t nrows_x;
int64_t nrows_y;
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int64_t nb11;
int64_t nb12;
int64_t nb13;
int64_t ne12;
int64_t ne13;
float scale;
float max_bias;
float m0;
float m1;
};
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
#ifdef __clang__ #ifdef __clang__
@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
#endif // __clang__ #endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T> template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32( static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float * x, const T * mask, float * dst, const soft_max_params p) {
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension const int64_t i03 = blockIdx.z;
const int64_t i02 = blockIdx.y;
const int64_t i01 = blockIdx.x;
//TODO: noncontigous inputs/outputs
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
const int64_t i11 = i01;
const int64_t i12 = i02 % p.ne12;
const int64_t i13 = i03 % p.ne13;
x += int64_t(rowx)*ncols; x += int64_t(rowx)*ncols;
mask += int64_t(rowy)*ncols * (mask != nullptr); mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
dst += int64_t(rowx)*ncols; dst += int64_t(rowx)*ncols;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
@ -38,7 +70,7 @@ static __global__ void soft_max_f32(
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1); const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[]; extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@ -55,7 +87,7 @@ static __global__ void soft_max_f32(
break; break;
} }
const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val; vals[col] = val;
max_val = max(max_val, val); max_val = max(max_val, val);
@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32(
} }
} }
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
auto launch_kernel = [=](auto I) -> bool {
constexpr int ncols = decltype(I)::value;
constexpr int block = (ncols > 1024 ? 1024 : ncols);
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, p);
return true;
}
return false;
};
// unary fold over launch_kernel
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
return;
}
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
}
template<typename T> template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE; int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1); const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1); const dim3 block_nums(params.ne01, params.ne02, params.ne03);
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const int id = ggml_cuda_get_device();
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { if (nbytes_shared <= smpbo) {
switch (ncols_x) { launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else { } else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
} }
} }
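The hand-written switch over ncols_x is gone: launch_soft_max_kernels walks a compile-time list of column counts with a short-circuiting fold expression and falls back to the generic kernel when none matches. The same pattern in isolation, with placeholder kernels (a sketch, not the ggml code):

```cuda
// Compile-time dispatch over a list of specializations via a fold expression.
// try_one() is invoked for each candidate N until one matches; the || fold
// short-circuits, so at most one specialization runs.
#include <type_traits>

template <int N> static void run_fixed() { /* path specialized for N columns */ }
static void run_generic()                { /* fallback, bounds not unrolled   */ }

template <int... Ns>
static void dispatch(int n) {
    auto try_one = [&](auto I) -> bool {
        constexpr int N = decltype(I)::value;
        if (n == N) {
            run_fixed<N>();
            return true;
        }
        return false;
    };
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    run_generic();
}

// e.g. dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(ncols);
```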
@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1]; const int64_t nrows_y = src0->ne[1];
const int64_t ne00 = src0->ne[0];
float scale = 1.0f; float scale = 1.0f;
float max_bias = 0.0f; float max_bias = 0.0f;
@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
soft_max_params params = {};
params.nheads = src0->ne[2];
params.n_head_log2 = n_head_log2;
params.ncols = ne00;
params.nrows_x = nrows_x;
params.nrows_y = nrows_y;
params.ne00 = src0->ne[0];
params.ne01 = src0->ne[1];
params.ne02 = src0->ne[2];
params.ne03 = src0->ne[3];
params.nb11 = nb11;
params.nb12 = nb12;
params.nb13 = nb13;
params.ne12 = ne12;
params.ne13 = ne13;
params.scale = scale;
params.max_bias = max_bias;
params.m0 = m0;
params.m1 = m1;
if (use_f16) { if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
} else { } else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
} }
} }
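Note that the host now derives the ALiBi parameters from src0->ne[2] (the head count) rather than from nrows_x/nrows_y. For reference, the slope encoded by m0, m1 and n_head_log2 for head h follows the usual ALiBi schedule (restating the definitions above together with ggml's get_alibi_slope; when max_bias is 0 the slope is simply 1):

$$
n_{\mathrm{hl2}} = 2^{\lfloor \log_2 n_{\mathrm{head}} \rfloor}, \qquad
m_0 = 2^{-\mathrm{max\_bias}/n_{\mathrm{hl2}}}, \qquad
m_1 = 2^{-\mathrm{max\_bias}/(2\,n_{\mathrm{hl2}})}
$$

$$
\mathrm{slope}(h) =
\begin{cases}
m_0^{\,h+1}, & h < n_{\mathrm{hl2}} \\
m_1^{\,2(h-n_{\mathrm{hl2}})+1}, & h \ge n_{\mathrm{hl2}}
\end{cases}
$$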

View file

@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2) __global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, const int src2_nb1, const int src2_nb2, const int src3_nb1,
float * __restrict__ dst, const int64_t L) { const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
GGML_UNUSED(src1_nb0); const int64_t s_off, const int64_t d_inner, const int64_t L) {
GGML_UNUSED(src2_nb0);
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B const int bidx = blockIdx.x; // split along B (sequences)
const int bidy = blockIdx.y; // split along D const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int wid = tid / 32; const int wid = tid / 32;
const int wtid = tid % 32; const int wtid = tid % 32;
@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
float * smem_A = smem; float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA; float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
const int stride_s0 = src0_nb1 / sizeof(float); const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_x = src1_nb1 / sizeof(float); const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float); const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb1 / sizeof(float); const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float); const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0; const int stride_s = stride_s0;
const int stride_y = stride_x; const int stride_y = d_inner;
// can N not be 16? for example 32? // can N not be 16? for example 32?
if (N == 16) { if (N == 16) {
@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
} }
} }
// assumes as many threads as d_state
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
const int head_idx = (blockIdx.x * splitH) / d_head;
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
float state[splitH];
// for the parallel accumulation
__shared__ float stateC[splitH * d_state];
#pragma unroll
for (int j = 0; j < splitH; j++) {
state[j] = s0_block[j * d_state + threadIdx.x];
}
for (int64_t i = 0; i < n_tok; i++) {
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
// TODO: only calculate B and C once per head group
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
float dt_soft_plus = dt_block[i * stride_dt];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
const float dA = expf(dt_soft_plus * A_block[0]);
const float B = B_block[i * stride_B + threadIdx.x];
const float C = C_block[i * stride_C + threadIdx.x];
// across d_head
#pragma unroll
for (int j = 0; j < splitH; j++) {
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B * x_dt);
stateC[j * d_state + threadIdx.x] = state[j] * C;
}
__syncthreads();
// parallel accumulation for stateC
// TODO: simplify
{
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
// reduce until w matches the warp size
// TODO: does this work even when the physical warp size is 64?
#pragma unroll
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
// (assuming there are d_state threads)
#pragma unroll
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
// TODO: check for bank conflicts
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
stateC[k] += stateC[k + (w >> 1)];
}
__syncthreads();
}
static_assert(splitH >= d_state / WARP_SIZE);
#pragma unroll
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
y = warp_reduce_sum(y);
// store the above accumulations
if (threadIdx.x % WARP_SIZE == 0) {
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
y_block[i * stride_y + k] = y;
}
}
}
}
// write back the state
#pragma unroll
for (int j = 0; j < splitH; j++) {
s_block[j * d_state + threadIdx.x] = state[j];
}
}
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, const float * src4, const float * src5, const int32_t * src6, float * dst,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) { cudaStream_t stream) {
const int threads = 128; const int threads = 128;
// todo: consider D cannot be divided,does this situation exist? // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
GGML_ASSERT(D % threads == 0); if (src3_nb1 == sizeof(float)) {
const dim3 blocks(B, (D + threads - 1) / threads, 1); // Mamba-2
const int smem_size = (threads * (N + 1) * 2) * sizeof(float); if (d_state == 128) {
if (N == 16) { GGML_ASSERT(d_state % threads == 0);
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>( // NOTE: can be any power of two between 4 and 64
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, const int splitH = 16;
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else { } else {
GGML_ABORT("doesn't support N!=16."); GGML_ABORT("doesn't support d_state!=128.");
}
} else {
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
} else {
GGML_ABORT("doesn't support d_state!=16.");
}
} }
} }
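The grouped Mamba-2 kernel accumulates stateC with a two-stage reduction: a shared-memory tree reduction until one warp's worth of partial sums is left, then a warp shuffle reduction. Below is a simplified, standalone version of that pattern; it drops the splitH interleaving of the real kernel and is meant only to show the shape of the reduction:

```cuda
// Standalone sketch of the two-stage reduction (one output per block row).
// Assumes blockDim.x == d_state, d_state a power of two >= 32, warp size 32.
#include <cuda_runtime.h>

#define MY_WARP_SIZE 32

static __device__ float warp_sum(float x) {
#pragma unroll
    for (int offset = MY_WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, MY_WARP_SIZE);
    }
    return x;
}

template <int d_state>
static __global__ void reduce_rows(const float * partial, float * out) {
    __shared__ float buf[d_state];
    const int row = blockIdx.x;

    buf[threadIdx.x] = partial[row*d_state + threadIdx.x];
    __syncthreads();

    // stage 1: shared-memory tree reduction down to one warp's worth of values
#pragma unroll
    for (int w = d_state/2; w >= MY_WARP_SIZE; w >>= 1) {
        if (threadIdx.x < w) {
            buf[threadIdx.x] += buf[threadIdx.x + w];
        }
        __syncthreads();
    }

    // stage 2: final reduction inside a single warp via shuffles
    if (threadIdx.x < MY_WARP_SIZE) {
        const float y = warp_sum(buf[threadIdx.x]);
        if (threadIdx.x == 0) {
            out[row] = y;
        }
    }
}

// launch example: reduce_rows<128><<<nrows, 128>>>(partial, out);
```

The real kernel reduces splitH values per block at once, which is why its index arithmetic is more involved.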
@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C const struct ggml_tensor * src5 = dst->src[5]; // C
const struct ggml_tensor * src6 = dst->src[6]; // ids
// const int64_t d_state = src0->ne[0];
// const int64_t d_inner = src0->ne[1];
// const int64_t l = src1->ne[1];
// const int64_t b = src0->ne[2];
const int64_t nc = src0->ne[0]; // d_state const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner const int64_t nr = src0->ne[1]; // head_dim or 1
const int64_t n_t = src1->ne[1]; // number of tokens per sequence const int64_t nh = src1->ne[1]; // n_head
const int64_t n_s = src0->ne[2]; // number of sequences in the batch const int64_t ng = src4->ne[1]; // n_group
const int64_t n_t = src1->ne[2]; // number of tokens per sequence
const int64_t n_s = src1->ne[3]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); const int64_t s_off = ggml_nelements(src1) * sizeof(float);
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data; const float * src1_d = (const float *) src1->data;
@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * src3_d = (const float *) src3->data; const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data; const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data; const float * src5_d = (const float *) src5->data;
const int32_t * src6_d = (const int32_t *) src6->data;
float * dst_d = (float *) dst->data; float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src6->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
s_off, nc, nr, nh, ng, n_t, n_s, stream);
} }

View file

@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_silu>(ctx, dst); ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
} }
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
}
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
/* silu_back */ /* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) { static __device__ __forceinline__ float op_silu_back(float grad, float x) {

View file

@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -229,7 +229,11 @@ typedef struct {
uint64_t nb21; uint64_t nb21;
uint64_t nb22; uint64_t nb22;
uint64_t nb23; uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31; uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1; int32_t ne1;
int32_t ne2; int32_t ne2;
float scale; float scale;
@ -461,9 +465,21 @@ typedef struct {
} ggml_metal_kargs_sum_rows; } ggml_metal_kargs_sum_rows;
typedef struct { typedef struct {
int64_t ne00; int32_t ne00;
int64_t ne01; int32_t ne01;
int64_t ne02; int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float scale; float scale;
float max_bias; float max_bias;
float m0; float m0;
@ -499,26 +515,25 @@ typedef struct {
typedef struct { typedef struct {
int64_t d_state; int64_t d_state;
int64_t d_inner; int64_t d_inner;
int64_t n_head;
int64_t n_group;
int64_t n_seq_tokens; int64_t n_seq_tokens;
int64_t n_seqs; int64_t n_seqs;
uint64_t nb00;
uint64_t nb01; uint64_t nb01;
uint64_t nb02; uint64_t nb02;
uint64_t nb10; uint64_t nb03;
uint64_t nb11; uint64_t nb11;
uint64_t nb12; uint64_t nb12;
uint64_t nb13; uint64_t nb13;
uint64_t nb20;
uint64_t nb21; uint64_t nb21;
uint64_t nb22; uint64_t nb22;
uint64_t nb30;
uint64_t nb31; uint64_t nb31;
uint64_t nb40;
uint64_t nb41; uint64_t nb41;
uint64_t nb42; uint64_t nb42;
uint64_t nb50; uint64_t nb43;
uint64_t nb51; uint64_t nb51;
uint64_t nb52; uint64_t nb52;
uint64_t nb53;
} ggml_metal_kargs_ssm_scan; } ggml_metal_kargs_ssm_scan;
typedef struct { typedef struct {

View file

@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@ -529,6 +530,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REGLU, GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU, GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU, GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@ -1196,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@ -1508,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@ -1691,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default: default:
return false; return false;
@ -1725,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM: case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@ -2454,6 +2462,12 @@ static bool ggml_metal_encode_node(
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break; break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
case GGML_GLU_OP_GEGLU_QUICK:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -2644,10 +2658,7 @@ static bool ggml_metal_encode_node(
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0); const uint32_t n_head = src0->ne[2];
const int64_t nrows_y = src0->ne[1];
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@ -2707,6 +2718,18 @@ static bool ggml_metal_encode_node(
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
/*.ne01 =*/ ne01, /*.ne01 =*/ ne01,
/*.ne02 =*/ ne02, /*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.scale =*/ scale, /*.scale =*/ scale,
/*.max_bias =*/ max_bias, /*.max_bias =*/ max_bias,
/*.m0 =*/ m0, /*.m0 =*/ m0,
@ -2726,7 +2749,7 @@ static bool ggml_metal_encode_node(
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
{ {
@ -2800,71 +2823,91 @@ static bool ggml_metal_encode_node(
struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5]; struct ggml_tensor * src5 = node->src[5];
struct ggml_tensor * src6 = node->src[6];
GGML_ASSERT(src3); GGML_ASSERT(src3);
GGML_ASSERT(src4); GGML_ASSERT(src4);
GGML_ASSERT(src5); GGML_ASSERT(src5);
GGML_ASSERT(src6);
size_t offs_src3 = 0; size_t offs_src3 = 0;
size_t offs_src4 = 0; size_t offs_src4 = 0;
size_t offs_src5 = 0; size_t offs_src5 = 0;
size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
const uint64_t nb30 = src3->nb[0]; const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
const uint64_t nb31 = src3->nb[1]; const uint64_t nb31 = src3->nb[1];
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
const uint64_t nb40 = src4->nb[0]; const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
const uint64_t nb41 = src4->nb[1]; const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2]; const uint64_t nb42 = src4->nb[2];
const uint64_t nb43 = src4->nb[3];
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
const uint64_t nb50 = src5->nb[0]; const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
const uint64_t nb51 = src5->nb[1]; const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2]; const uint64_t nb52 = src5->nb[2];
const uint64_t nb53 = src5->nb[3];
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
const int64_t d_state = ne00; const int64_t d_state = ne00;
const int64_t d_inner = ne01; const int64_t d_inner = ne01;
const int64_t n_seq_tokens = ne11; const int64_t n_head = ne02;
const int64_t n_seqs = ne02; const int64_t n_group = ne41;
const int64_t n_seq_tokens = ne12;
const int64_t n_seqs = ne13;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; id<MTLComputePipelineState> pipeline = nil;
if (ne30 == 1) {
// Mamba-2
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}
ggml_metal_kargs_ssm_scan args = { ggml_metal_kargs_ssm_scan args = {
/*.d_state =*/ d_state, /*.d_state =*/ d_state,
/*.d_inner =*/ d_inner, /*.d_inner =*/ d_inner,
/*.n_head =*/ n_head,
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens, /*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs, /*.n_seqs =*/ n_seqs,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01, /*.nb01 =*/ nb01,
/*.nb02 =*/ nb02, /*.nb02 =*/ nb02,
/*.nb10 =*/ nb10, /*.nb03 =*/ nb03,
/*.nb11 =*/ nb11, /*.nb11 =*/ nb11,
/*.nb12 =*/ nb12, /*.nb12 =*/ nb12,
/*.nb13 =*/ nb13, /*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21, /*.nb21 =*/ nb21,
/*.nb22 =*/ nb22, /*.nb22 =*/ nb22,
/*.nb30 =*/ nb30,
/*.nb31 =*/ nb31, /*.nb31 =*/ nb31,
/*.nb40 =*/ nb40,
/*.nb41 =*/ nb41, /*.nb41 =*/ nb41,
/*.nb42 =*/ nb42, /*.nb42 =*/ nb42,
/*.nb50 =*/ nb50, /*.nb43 =*/ nb43,
/*.nb51 =*/ nb51, /*.nb51 =*/ nb51,
/*.nb52 =*/ nb52, /*.nb52 =*/ nb52,
/*.nb53 =*/ nb53,
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@ -2874,10 +2917,17 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6]; [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBytes:&args length:sizeof(args) atIndex:7]; [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&args length:sizeof(args) atIndex:8];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break; } break;
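// Sketch (not the backend code itself) of how the encoder above picks between the
// two SSM_SCAN pipelines: src3 is the A tensor, and Mamba-2 stores a single decay
// value per head (A is {1, n_head}), so ne30 == 1 identifies it, while Mamba-1's A
// has d_state entries per head. The grid sizes mirror the dispatchThreadgroups calls
// above; the d_inner == 1 assumption in the Mamba-1 branch is the one the encoder asserts.
#include <cstdint>
struct ssm_scan_dispatch {
    bool    is_mamba2;
    int64_t grid[3]; // threadgroups along x, y, z
};
static ssm_scan_dispatch pick_ssm_scan(int64_t ne30, int64_t d_inner, int64_t n_head, int64_t n_seqs) {
    ssm_scan_dispatch d;
    d.is_mamba2 = (ne30 == 1); // one A value per head -> Mamba-2 (group) kernel
    if (d.is_mamba2) {
        d.grid[0] = d_inner; d.grid[1] = n_head; d.grid[2] = n_seqs;
    } else {
        d.grid[0] = n_head;  d.grid[1] = n_seqs; d.grid[2] = 1; // d_inner == 1 here
    }
    return d;
}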
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:
{ {
@ -4979,7 +5029,11 @@ static bool ggml_metal_encode_node(
/*.nb21 =*/ nb21, /*.nb21 =*/ nb21,
/*.nb22 =*/ nb22, /*.nb22 =*/ nb22,
/*.nb23 =*/ nb23, /*.nb23 =*/ nb23,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31, /*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1, /*.ne1 =*/ ne1,
/*.ne2 =*/ ne2, /*.ne2 =*/ ne2,
/*.scale =*/ scale, /*.scale =*/ scale,

View file

@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
} }
void quantize_q4_0(device const float * src, device block_q4_0 & dst) { void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f; float max = 0.0f;
@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
} }
void quantize_q5_0(device const float * src, device block_q5_0 & dst) { void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f; float max = 0.0f;
@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
} }
void quantize_q8_0(device const float * src, device block_q8_0 & dst) { void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) { for (int j = 0; j < QK8_0; j++) {
@ -1258,6 +1261,50 @@ kernel void kernel_swiglu(
} }
} }
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_quick(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
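// Minimal CPU reference (a sketch, not the shader) for the element-wise math of the two
// GLU kernels above; it uses exact erff from <cmath> where the Metal kernel uses
// erf_approx, and the constants match the shader definitions.
#include <cmath>
static inline float geglu_erf_ref(float x0, float x1) {
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
    const float gelu_erf   = 0.5f*x0*(1.0f + erff(x0*SQRT_2_INV));
    return gelu_erf*x1; // gate the second input by GELU(first input)
}
static inline float geglu_quick_ref(float x0, float x1) {
    const float GELU_QUICK_COEF = -1.702f; // sigmoid approximation of GELU
    const float gelu_quick = x0*(1.0f/(1.0f + expf(GELU_QUICK_COEF*x0)));
    return gelu_quick*x1;
}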
template <bool norm> template <bool norm>
kernel void kernel_sum_rows( kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args, constant ggml_metal_kargs_sum_rows & args,
@ -1320,24 +1367,28 @@ kernel void kernel_soft_max(
device char * dst, device char * dst,
constant ggml_metal_kargs_soft_max & args, constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) { uint3 tptg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (args.ne02*args.ne01); const int32_t i03 = tgpig.z;
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; const int32_t i02 = tgpig.y;
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); const int32_t i01 = tgpig.x;
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); const int32_t i13 = i03%args.ne13;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; const int32_t i12 = i02%args.ne12;
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); const int32_t i11 = i01;
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f; float slope = 1.0f;
// ALiBi // ALiBi
if (args.max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const int64_t h = i02; const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1; const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@ -1348,13 +1399,13 @@ kernel void kernel_soft_max(
// parallel max // parallel max
float lmax = -INFINITY; float lmax = -INFINITY;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
} }
// find the max value in the block // find the max value in the block
float max_val = simd_max(lmax); float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) { if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) { if (sgitg == 0) {
buf[tiisg] = -INFINITY; buf[tiisg] = -INFINITY;
} }
@ -1373,7 +1424,7 @@ kernel void kernel_soft_max(
// parallel sum // parallel sum
float lsum = 0.0f; float lsum = 0.0f;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0; lsum += exp_psrc0;
pdst[i00] = exp_psrc0; pdst[i00] = exp_psrc0;
@ -1385,7 +1436,7 @@ kernel void kernel_soft_max(
float sum = simd_sum(lsum); float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) { if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) { if (sgitg == 0) {
buf[tiisg] = 0.0f; buf[tiisg] = 0.0f;
} }
@ -1404,7 +1455,7 @@ kernel void kernel_soft_max(
const float inv_sum = 1.0f/sum; const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum; pdst[i00] *= inv_sum;
} }
} }
@ -1416,23 +1467,27 @@ kernel void kernel_soft_max_4(
device char * dst, device char * dst,
constant ggml_metal_kargs_soft_max & args, constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) { uint3 tptg[[threads_per_threadgroup]]) {
const int64_t i03 = (tgpig) / (args.ne02*args.ne01); const int32_t i03 = tgpig.z;
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; const int32_t i02 = tgpig.y;
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); const int32_t i01 = tgpig.x;
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; const int32_t i13 = i03%args.ne13;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; const int32_t i12 = i02%args.ne12;
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; const int32_t i11 = i01;
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f; float slope = 1.0f;
if (args.max_bias > 0.0f) { if (args.max_bias > 0.0f) {
const int64_t h = i02; const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1; const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@ -1443,14 +1498,14 @@ kernel void kernel_soft_max_4(
// parallel max // parallel max
float4 lmax4 = -INFINITY; float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
} }
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax); float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) { if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) { if (sgitg == 0) {
buf[tiisg] = -INFINITY; buf[tiisg] = -INFINITY;
} }
@ -1469,7 +1524,7 @@ kernel void kernel_soft_max_4(
// parallel sum // parallel sum
float4 lsum4 = 0.0f; float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4; lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4; pdst4[i00] = exp_psrc4;
@ -1483,7 +1538,7 @@ kernel void kernel_soft_max_4(
float sum = simd_sum(lsum); float sum = simd_sum(lsum);
if (ntg > N_SIMDWIDTH) { if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) { if (sgitg == 0) {
buf[tiisg] = 0.0f; buf[tiisg] = 0.0f;
} }
@ -1502,7 +1557,7 @@ kernel void kernel_soft_max_4(
const float inv_sum = 1.0f/sum; const float inv_sum = 1.0f/sum;
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum; pdst4[i00] *= inv_sum;
} }
} }
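// Scalar reference (a sketch, not the shader) for one row of the soft_max kernels above:
// y_i = exp(scale*x_i + slope*m_i - max) / sum_j exp(scale*x_j + slope*m_j - max),
// where m is the optional mask row (broadcast over ne12/ne13 by the new indexing) and
// slope is the ALiBi factor derived from the head index.
#include <cmath>
static void soft_max_row_ref(const float * x, const float * mask, float * y,
                             int ne00, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ne00; i++) {
        max_val = fmaxf(max_val, x[i]*scale + (mask ? slope*mask[i] : 0.0f));
    }
    float sum = 0.0f;
    for (int i = 0; i < ne00; i++) {
        y[i] = expf((x[i]*scale + (mask ? slope*mask[i] : 0.0f)) - max_val);
        sum += y[i];
    }
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < ne00; i++) {
        y[i] *= inv_sum;
    }
}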
@ -1588,7 +1643,7 @@ kernel void kernel_ssm_conv_f32(
x[0] = sumf; x[0] = sumf;
} }
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32( kernel void kernel_ssm_scan_f32(
device const void * src0, device const void * src0,
device const void * src1, device const void * src1,
@ -1596,46 +1651,119 @@ kernel void kernel_ssm_scan_f32(
device const void * src3, device const void * src3,
device const void * src4, device const void * src4,
device const void * src5, device const void * src5,
device const void * src6,
device float * dst, device float * dst,
constant ggml_metal_kargs_ssm_scan & args, constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x; const int64_t i1 = 0;
const int64_t i3 = tgpig.y; const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state; const int64_t nc = args.d_state;
// const int64_t nr = args.d_inner; const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens; const int64_t n_t = args.n_seq_tokens;
// const int64_t n_s = args.n_seqs;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) { for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
if (i2 > 0) { const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
s0 = s; const float x_dt = x[0] * dt_soft_plus;
}
// i1 == 0
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) { for (int64_t i0 = 0; i0 < nc; ++i0) {
int64_t i = i0; const int64_t i = i0 + i1*nc;
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0]; sumf += state * C[i0];
s[i] = state; s[i] = state;
} }
y[0] = sumf; y[0] = sumf;
// recurse
s0 = s;
}
}
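// CPU sketch of one (head, token) step of the Mamba-1 scan above, under the same layout
// assumptions as the kernel: dt goes through softplus (bypassed for dt > 20 to avoid
// overflow), the diagonal state update is s[i] = s[i]*exp(dt'*A[i]) + B[i]*x*dt', and the
// output is y = sum_i s[i]*C[i].
#include <cmath>
static float ssm_scan_step_ref(float * s, const float * A, const float * B, const float * C,
                               float x, float dt, int d_state) {
    const float dt_soft_plus = dt <= 20.0f ? logf(1.0f + expf(dt)) : dt;
    const float x_dt = x * dt_soft_plus;
    float y = 0.0f;
    for (int i = 0; i < d_state; i++) {
        s[i] = s[i]*expf(dt_soft_plus*A[i]) + B[i]*x_dt;
        y   += s[i]*C[i];
    }
    return y;
}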
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
// recurse
s0 = s;
} }
} }
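// The group variant differs from the Mamba-1 kernel above in two ways: A holds a single
// decay per head, so dA = exp(dt'*A[0]) is hoisted out of the state loop, and B/C are
// shared by all heads of a group through (ir & (ng - 1)). That bitmask equals ir % ng
// only when n_group is a power of two, which this indexing appears to assume.
static inline long long ssm_group_of_head(long long ir, long long ng) {
    return ir & (ng - 1); // == ir % ng for power-of-two ng
}
static_assert((7 & (4 - 1)) == 7 % 4, "mask == modulo when the group count is a power of two");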
@ -3776,7 +3904,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory // load the mask in shared memory
#pragma unroll(Q) #pragma unroll(Q)
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
const float m = pm[ic + tiisg]; const float m = pm[ic + tiisg];
@ -4262,7 +4390,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q; const bool has_mask = mask != q;
// pointer to the mask // pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31); device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f; float slope = 1.0f;

View file

@ -1,7 +1,9 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define GELU_COEF_A 0.044715f #define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// geglu // geglu
@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
dst_row[i0] = silu*x1; dst_row[i0] = silu*x1;
} }
} }
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
kernel void kernel_geglu_erf(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_erf_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
//------------------------------------------------------------------------------
// geglu_quick
//------------------------------------------------------------------------------
kernel void kernel_geglu_quick(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
kernel void kernel_geglu_quick_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}

View file

@ -240,6 +240,21 @@ enum vk_device_architecture {
INTEL_XE2, INTEL_XE2,
}; };
// HSK x HSV
enum FaHeadSizes {
FA_HEAD_SIZE_64,
FA_HEAD_SIZE_80,
FA_HEAD_SIZE_96,
FA_HEAD_SIZE_112,
FA_HEAD_SIZE_128,
FA_HEAD_SIZE_192,
FA_HEAD_SIZE_192_128,
FA_HEAD_SIZE_256,
FA_HEAD_SIZE_576_512,
FA_HEAD_SIZE_UNSUPPORTED,
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
};
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties(); vk::PhysicalDeviceProperties props = device.getProperties();
@ -457,6 +472,8 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2]; vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2]; vk_pipeline pipeline_swiglu[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_silu_back_f32;
@ -483,26 +500,11 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_cwhn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_split_k_reduce; vk_pipeline pipeline_flash_attn_split_k_reduce;
@ -649,6 +651,7 @@ struct vk_flash_attn_push_constants {
uint32_t nev2; uint32_t nev2;
uint32_t nev3; uint32_t nev3;
uint32_t nem1; uint32_t nem1;
uint32_t nem2;
uint32_t nb01; uint32_t nb01;
uint32_t nb02; uint32_t nb02;
@ -659,7 +662,6 @@ struct vk_flash_attn_push_constants {
uint32_t nb21; uint32_t nb21;
uint32_t nb22; uint32_t nb22;
uint32_t nb23; uint32_t nb23;
uint32_t nb31;
float scale; float scale;
float max_bias; float max_bias;
@ -674,6 +676,7 @@ struct vk_flash_attn_push_constants {
uint32_t split_kv; uint32_t split_kv;
uint32_t k_num; uint32_t k_num;
}; };
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
struct vk_op_push_constants { struct vk_op_push_constants {
uint32_t KX; uint32_t KX;
@ -772,6 +775,14 @@ struct vk_op_rope_push_constants {
struct vk_op_soft_max_push_constants { struct vk_op_soft_max_push_constants {
uint32_t KX; uint32_t KX;
uint32_t KY; uint32_t KY;
uint32_t ne00;
uint32_t ne01;
uint32_t ne02;
uint32_t ne12;
uint32_t ne13;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
float scale; float scale;
float max_bias; float max_bias;
float m0; float m0;
@ -1010,7 +1021,7 @@ struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the // number of additional consecutive nodes that are being fused with the
// node currently being processed // node currently being processed
uint32_t num_additional_fused_ops {}; int num_additional_fused_ops {};
}; };
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@ -1706,6 +1717,35 @@ enum FaCodePath {
FA_COOPMAT2, FA_COOPMAT2,
}; };
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
if (hsk != 192 && hsk != 576 && hsk != hsv) {
return FA_HEAD_SIZE_UNSUPPORTED;
}
switch (hsk) {
case 64: return FA_HEAD_SIZE_64;
case 80: return FA_HEAD_SIZE_80;
case 96: return FA_HEAD_SIZE_96;
case 112: return FA_HEAD_SIZE_112;
case 128: return FA_HEAD_SIZE_128;
case 192:
if (hsv == 192) {
return FA_HEAD_SIZE_192;
} else if (hsv == 128) {
return FA_HEAD_SIZE_192_128;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
case 256: return FA_HEAD_SIZE_256;
case 576:
if (hsv == 512) {
return FA_HEAD_SIZE_576_512;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
default: return FA_HEAD_SIZE_UNSUPPORTED;
}
}
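// Usage sketch for the mapping above: the (HSK, HSV) pair of K/V head sizes selects one
// slot of the per-head-size pipeline arrays, and any combination outside the table comes
// back as FA_HEAD_SIZE_UNSUPPORTED. The values below are just illustrative inputs.
static void example_head_size_lookup() {
    const FaHeadSizes mixed = fa_get_head_sizes(576, 512); // FA_HEAD_SIZE_576_512 (K head larger than V head)
    const FaHeadSizes plain = fa_get_head_sizes(128, 128); // FA_HEAD_SIZE_128
    const FaHeadSizes none  = fa_get_head_sizes(100, 100); // FA_HEAD_SIZE_UNSUPPORTED
    (void) mixed; (void) plain; (void) none;
}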
// number of rows/cols for flash attention shader // number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
@ -1726,8 +1766,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
} }
} }
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp); GGML_UNUSED(clamp);
GGML_UNUSED(hsv);
if (path == FA_SCALAR) { if (path == FA_SCALAR) {
if (small_rows) { if (small_rows) {
@ -1751,7 +1792,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
} }
// small cols to reduce register count // small cols to reduce register count
if (ggml_is_quantized(type) || D == 256) { if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32}; return {64, 32};
} }
return {64, 64}; return {64, 64};
@ -2044,19 +2085,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
}; };
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
}; };
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best. // For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80. // can't use 256 for D==80.
// For scalar, use 128 (arbitrary) // For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size ? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128); : ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@ -2065,26 +2108,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
}; };
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@ -2793,6 +2839,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_GLU(geglu) CREATE_GLU(geglu)
CREATE_GLU(reglu) CREATE_GLU(reglu)
CREATE_GLU(swiglu) CREATE_GLU(swiglu)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU #undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@ -3703,7 +3751,6 @@ static void ggml_vk_instance_init() {
} }
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@ -6017,24 +6064,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
} }
} }
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
// Needs to be kept up to date on shader changes // Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows; const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
const uint32_t masksh = Bc * Br * sizeof(float);
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
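// Worked instance of the estimate above, using hypothetical sizes (the real Br/Bc
// constants are defined earlier in this file): wg_size = 128, Br = 8, Bc = 64 and
// HSK = 256 give 512 + 2048 + 2048 + 8448 = 13056 bytes, comfortably under a 32 KiB
// maxComputeSharedMemorySize.
constexpr unsigned scalar_fa_shmem_example =
      128*4              // tmpsh   = wg_size * sizeof(float)
    + 128*4*4            // tmpshv4 = wg_size * 4 * sizeof(float)
    + 64*8*4             // masksh  = Bc * Br * sizeof(float)
    + 8*(256/4 + 2)*4*4; // Qf      = Br * (hsk/4 + 2) * 4 * sizeof(float)
static_assert(scalar_fa_shmem_example == 13056, "example budget for the scalar FA path");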
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t acctype = f32acc ? 4 : 2; const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8; const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype; const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype; const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = D / 4 + 2; const uint32_t kshstride = hsk / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float); const uint32_t slope = Br * sizeof(float);
@ -6042,7 +6112,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported; return supported;
} }
@ -6064,13 +6134,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_TENSOR_LOCALS(size_t, nb, dst, nb) GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nbm1 = mask ? mask->nb[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t D = neq0; const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
uint32_t N = neq1; uint32_t N = neq1;
const uint32_t KV = nek1; const uint32_t KV = nek1;
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == HSV);
GGML_ASSERT(ne2 == N); GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous // input tensor rows must be contiguous
@ -6078,12 +6149,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == D); GGML_ASSERT(neq0 == HSK);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N); GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(nev1 == nek1); GGML_ASSERT(nev1 == nek1);
@ -6104,7 +6172,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) { if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR; path = FA_SCALAR;
@ -6157,47 +6225,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
path = FA_SCALAR; path = FA_SCALAR;
} }
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR &&
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
small_rows = true;
}
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
switch (path) { switch (path) {
case FA_SCALAR: case FA_SCALAR:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
case FA_COOPMAT1: case FA_COOPMAT1:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
case FA_COOPMAT2: case FA_COOPMAT2:
switch (D) { pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
break; break;
default: default:
GGML_ASSERT(0); GGML_ASSERT(0);
@ -6227,7 +6273,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Try to use split_k when KV is large enough to be worth the overhead // Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM. // Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / workgroups_y; split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) { if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple // Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that. // of "align", so recompute split_k based on that.
@ -6237,9 +6283,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
} }
} }
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) { if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large"); GGML_ABORT("Requested preallocation size is too large");
} }
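// Worked instance of the sizing above, with hypothetical shapes: HSV = 128, ne1 = 64
// rows, ne3 = 1 batch and split_k = 4 give (128*64*4 + 64*4*2) * 4 * 1 = 133120 bytes
// (130 KiB) of split_k scratch, laid out as all O matrices first, then the per-row
// m/L pairs.
constexpr unsigned long long split_k_size_example =
    (128ull*64*sizeof(float) + 64*sizeof(float)*2) * 4 * 1;
static_assert(split_k_size_example == 133120, "example split_k scratch size");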
@ -6331,11 +6377,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
(uint32_t)neq2, (uint32_t)neq3, (uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3, (uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3, (uint32_t)nev2, (uint32_t)nev3,
nem1, nem1, nem2,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3, q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1,
scale, max_bias, logit_softcap, scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1, mask != nullptr, n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k }; gqa_ratio, split_kv, split_k };
@ -6358,13 +6403,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k }; const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{ {
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
}, },
pc2, { (uint32_t)ne1, 1, 1 }); pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
} else { } else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ {
@ -6558,6 +6603,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
default: default:
break; break;
} }
@ -7690,7 +7739,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
const uint32_t nrows_y = (uint32_t)src0->ne[1]; const uint32_t nrows_y = (uint32_t)src0->ne[1];
const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
const uint32_t n_head_kv = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@ -7699,6 +7754,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols, ncols,
src1 != nullptr ? nrows_y : (uint32_t)0, src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
ne12, ne13,
nb11, nb12, nb13,
scale, max_bias, scale, max_bias,
m0, m1, m0, m1,
n_head_log2, n_head_log2,
@ -8893,6 +8951,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break; break;
default: default:
return false; return false;
@ -9140,6 +9200,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break; break;
default: default:
@ -9358,6 +9420,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer; buf = tensor->buffer;
break; break;
default: default:
@ -10168,6 +10232,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU: case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU: case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) && return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@ -10248,19 +10314,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device); auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2; bool coopmat2 = device->coopmat2;
switch (op->src[0]->ne[0]) { FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
case 64: if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
case 80:
case 96:
case 112:
case 128:
case 256:
break;
default:
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false; return false;
} }
if (op->src[0]->type != GGML_TYPE_F32) { if (op->src[0]->type != GGML_TYPE_F32) {
@ -10272,6 +10327,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false; return false;
} }
// TODO: support broadcast
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
return false;
}
// It's straightforward to support different K/V dequant, but would // It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines // significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) { if (op->src[1]->type != op->src[2]->type) {
@ -10340,6 +10401,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false; return false;
} }
} break; } break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CONT: case GGML_OP_CONT:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:
@ -10430,6 +10497,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:

View file

@ -11,7 +11,8 @@
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp" #include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split; const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter; const uint32_t cols_per_thread = Bc / cols_per_iter;
@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
return elem; return elem;
} }
@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br]; shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4]; shared vec4 Qf[Br][HSK / 4];
void main() { void main() {
#ifdef NEEDS_INIT_IQ_SHMEM #ifdef NEEDS_INIT_IQ_SHMEM
@ -53,18 +54,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (D / 4); uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < D / 4 && if (r < Br && d < HSK / 4 &&
i * Br + r < N) { i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
} }
} }
barrier(); barrier();
vec4 Of[Br][D_per_thread / 4]; vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0); Of[r][d] = vec4(0.0);
} }
@ -99,6 +100,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif #endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
}
[[dont_unroll]] [[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) { for (uint32_t j = start_j; j < end_j; ++j) {
@ -112,7 +117,7 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -150,7 +155,7 @@ void main() {
uint32_t c = (idx + tid) % Bc; uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc; uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) { if (idx + tid < Bc * Br) {
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} }
} }
barrier(); barrier();
@ -191,14 +196,14 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d]; Of[r][d] = eMf[r] * Of[r][d];
} }
} }
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -255,7 +260,7 @@ void main() {
Lf[r] = tmpsh[d_tid]; Lf[r] = tmpsh[d_tid];
barrier(); barrier();
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d]; Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d]; tmpshv4[tid] = Of[r][d];
@ -277,11 +282,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index; uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
} }
@ -289,7 +294,7 @@ void main() {
} }
} }
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@ -305,18 +310,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = 1.0 / Lf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r]; Of[r][d] *= Lfrcp[r];
} }
} }
uint32_t o_offset = iq3*p.ne2*p.ne1; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) { if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
} }
@ -326,9 +331,9 @@ void main() {
} else { } else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) { if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
} }
} }
} }
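The offsets introduced in this shader are easier to follow outside of GLSL. The sketch below re-derives them in Python: the broadcast mask offset driven by nem2, the per-(split, batch) location of the intermediate O block, and the location of the L/m rows that follow all O blocks. Names mirror the shader (KV, nem1, nem2, HSV, ne1, ne3, k_num); the concrete sizes are made up for illustration.

# Element offsets (not bytes), mirroring the expressions used above.
def mask_offset(iq3: int, nem1: int, nem2: int, kv: int) -> int:
    # mask is broadcast across the batch dimension when nem2 != 1
    return (iq3 % nem2) * nem1 * kv if nem2 != 1 else 0

def split_k_o_offset(hsv: int, ne1: int, k_num: int, split_k_index: int, iq3: int) -> int:
    # one HSV x ne1 block of O per (split, batch), batches grouped by iq3
    return hsv * ne1 * (split_k_index + iq3 * k_num)

def split_k_lm_offset(hsv: int, ne1: int, ne3: int, k_num: int, split_k_index: int, iq3: int) -> int:
    # L and m rows live after all O blocks: two ne1-length rows per (split, batch)
    return hsv * ne1 * ne3 * k_num + ne1 * (split_k_index + iq3 * k_num) * 2

print(mask_offset(iq3=3, nem1=64, nem2=2, kv=256))                                  # 16384
print(split_k_o_offset(hsv=128, ne1=64, k_num=4, split_k_index=1, iq3=0))           # 8192
print(split_k_lm_offset(hsv=128, ne1=64, ne3=1, k_num=4, split_k_index=1, iq3=0))   # 32896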

View file

@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32; layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32; layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t Clamp = 0; layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t D_split = 16; layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint32_t N; uint32_t N;
@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
uint32_t nev2; uint32_t nev2;
uint32_t nev3; uint32_t nev3;
uint32_t nem1; uint32_t nem1;
uint32_t nem2;
uint32_t nb01; uint32_t nb01;
uint32_t nb02; uint32_t nb02;
@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
uint32_t nb21; uint32_t nb21;
uint32_t nb22; uint32_t nb22;
uint32_t nb23; uint32_t nb23;
uint32_t nb31;
float scale; float scale;
float max_bias; float max_bias;

View file

@ -13,7 +13,9 @@
#include "types.comp" #include "types.comp"
#include "flash_attn_base.comp" #include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split; const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4; const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split; const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
return elem; return elem;
} }
@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = D / 4 + 2; // in units of f16vec4 const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride]; shared f16vec4 Qf[Br * qstride];
// Avoid padding for D==256 to make it fit in 48KB shmem. // Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride]; shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride]; shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br]; shared float slope[Br];
@ -74,18 +76,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (D / 4); uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < D / 4 && if (r < Br && d < HSK / 4 &&
i * Br + r < N) { i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
} }
} }
barrier(); barrier();
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0); Of[r][d] = ACC_TYPEV4(0.0);
} }
@ -123,14 +125,18 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif #endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
}
[[dont_unroll]] [[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) { for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4); uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (D / 4); uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < D / 4) { if (c < Bc && d < HSK / 4) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -145,14 +151,14 @@ void main() {
} }
barrier(); barrier();
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it // This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < D / 16; ++d) { for (uint32_t d = 0; d < HSK / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@ -181,7 +187,7 @@ void main() {
uint32_t c = (idx + tid) % Bc; uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc; uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
} }
} }
barrier(); barrier();
@ -202,7 +208,7 @@ void main() {
eMf[r] = exp(Moldf - Mf[r]); eMf[r] = exp(Moldf - Mf[r]);
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d]; Of[r][d] = float16_t(eMf[r]) * Of[r][d];
} }
@ -217,7 +223,7 @@ void main() {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r]; Lf[r] += Pf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -280,7 +286,7 @@ void main() {
} }
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d]; Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d]; tmpshv4[tid] = Of[r][d];
@ -300,11 +306,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index; uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
} }
@ -312,7 +318,7 @@ void main() {
} }
} }
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@ -328,18 +334,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = 1.0 / Lf[r];
} }
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]); Of[r][d] *= float16_t(Lfrcp[r]);
} }
} }
uint32_t o_offset = iq3*p.ne2*p.ne1; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) { if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
} }
@ -349,9 +355,9 @@ void main() {
} else { } else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) { if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
} }
} }
} }

View file

@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
// Rows index by Q's dimension 2, and the first N rows are valid. // Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{ {
if (r < N && c < D) { if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * D + c; uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem); data_o[o_offset + offset] = D_TYPE(elem);
} }
return elem; return elem;
@ -86,9 +86,9 @@ void main() {
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif #endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader // hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV) if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@ -104,16 +104,16 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q; coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16; coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q); Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale); Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@ -130,15 +130,20 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
} }
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]] [[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) { for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T; coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S); S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) { if (p.logit_softcap != 0.0f) {
@ -155,7 +160,7 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} }
@ -203,42 +208,42 @@ void main() {
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum); rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V; coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
L = eM*L + rowsum; L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise // This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared // multiply rather than matrix multiply it has the diagonal element smeared
// across the row // across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce // resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O. // multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0); coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV); PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV); O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
} }
// If there is split_k, then the split_k resolve shader does the final // If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values. // division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) { if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * split_k_index; uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return; return;
} }
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag; coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce // resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@ -250,18 +255,18 @@ void main() {
O = Ldiag*O; O = Ldiag*O;
uint32_t o_offset = iq3*p.ne2*p.ne1; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O); coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) { if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else { } else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions // permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
} }
} }

View file

@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint D; uint D;
uint N; uint N;
uint ne3;
uint k_num; uint k_num;
} p; } p;
@ -19,13 +20,14 @@ void main() {
// Each workgroup handles a row // Each workgroup handles a row
const uint n = gl_WorkGroupID.x; const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
uint D = p.D; uint D = p.D;
uint N = p.N; uint N = p.N;
uint k_num = p.k_num; uint k_num = p.k_num;
uint l_offset = D * N * k_num + n; uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * k_num + N + n; uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2; uint lm_stride = N * 2;
// Compute the max m value for the row // Compute the max m value for the row
@ -49,11 +51,11 @@ void main() {
for (uint d = tid; d < D; d += BLOCK_SIZE) { for (uint d = tid; d < D; d += BLOCK_SIZE) {
float O = 0.0; float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) { [[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d; uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride]; float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset]; O += exp(m - m_max) * data_a[o_offset];
} }
O *= L; O *= L;
data_d[D * n + d] = O; data_d[iq3 * D * N + D * n + d] = O;
} }
} }
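Numerically, this reduce implements the usual log-sum-exp combine of flash-attention partials: rescale each split's O and L by exp(m_k - m_max), sum, then divide once by the combined L. A small NumPy sketch (shapes and data are illustrative) showing that the combine reproduces single-pass softmax attention:

import numpy as np

rng = np.random.default_rng(0)
D, KV, splits = 8, 64, 4
scores = rng.normal(size=KV)        # one query row of scaled Q*K^T
V = rng.normal(size=(KV, D))

# reference: single-pass softmax attention
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ V

# per-split partials, as the main flash-attention shader stores them
chunks = np.split(np.arange(KV), splits)
m = [scores[c].max() for c in chunks]
L = [np.exp(scores[c] - m[i]).sum() for i, c in enumerate(chunks)]
O = [np.exp(scores[c] - m[i]) @ V[c] for i, c in enumerate(chunks)]

# the reduce step performed by this shader
m_max = max(m)
scale = [np.exp(mi - m_max) for mi in m]
out = sum(s * o for s, o in zip(scale, O)) / sum(s * l for s, l in zip(scale, L))

assert np.allclose(out, ref)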

View file

@ -0,0 +1,27 @@
#version 450
#include "glu_head.comp"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.comp"

View file

@ -0,0 +1,11 @@
#version 450
#include "glu_head.comp"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.comp"

View file

@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
{ {
uint KX; uint KX;
uint KY; uint KY;
uint ne00;
uint ne01;
uint ne02;
uint ne12;
uint ne13;
uint nb11;
uint nb12;
uint nb13;
float scale; float scale;
float max_bias; float max_bias;
float m0; float m0;
@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
void soft_max(uint num_iters) { void soft_max(uint num_iters) {
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
const uint32_t i01 = rowx % p.ne01;
uint rowy_start = 0;
if (p.KY > 0) {
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
}
if (rowx >= p.nrows_x) { if (rowx >= p.nrows_x) {
return; return;
@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
// ALiBi // ALiBi
if (p.max_bias > 0.0f) { if (p.max_bias > 0.0f) {
const uint h = rowx/p.KY; // head index const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1; const float base = h < p.n_head_log2 ? p.m0 : p.m1;
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
FLOAT_TYPE b = FLOAT_TYPE(0); FLOAT_TYPE b = FLOAT_TYPE(0);
if (p.KY > 0 && col < p.KX) { if (p.KY > 0 && col < p.KX) {
b = data_b[rowy * p.KX + col]; b = data_b[rowy_start + col];
} }
FLOAT_TYPE v = a * p.scale + slope * b; FLOAT_TYPE v = a * p.scale + slope * b;
@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
if (idx < DATA_CACHE_SIZE) { if (idx < DATA_CACHE_SIZE) {
val = exp(data_cache[idx] - max_val); val = exp(data_cache[idx] - max_val);
} else { } else {
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
} }
sum += val; sum += val;
if (idx < DATA_CACHE_SIZE) { if (idx < DATA_CACHE_SIZE) {
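The new indexing decomposes the flat row id into (i01, i02, i03), converts the mask's byte strides to element strides (nb1x / nb[0] on the host side), and applies modulo broadcasting against ne12/ne13; the ALiBi head index likewise becomes (rowx / ne01) % ne02. A small Python sketch of that mapping, with made-up shapes:

def mask_row_start(rowx, ne01, ne02, ne12, ne13, nb11, nb12, nb13):
    i03 = rowx // (ne01 * ne02)
    i02 = (rowx - i03 * ne01 * ne02) // ne01
    i01 = rowx % ne01
    return i01 * nb11 + (i02 % ne12) * nb12 + (i03 % ne13) * nb13

def alibi_head(rowx, ne01, ne02):
    return (rowx // ne01) % ne02

# src0: 32 heads x 64 rows; mask broadcast over heads (ne12 == 1)
ne01, ne02, ne12, ne13 = 64, 32, 1, 1
kx = 256                                   # mask row length
nb11, nb12, nb13 = kx, kx * 64, kx * 64    # contiguous mask, element strides
print(mask_row_start(70, ne01, ne02, ne12, ne13, nb11, nb12, nb13))  # 1536 = row 6 of the mask
print(alibi_head(70, ne01, ne02))                                    # head 1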

View file

@ -607,6 +607,10 @@ void process_shaders() {
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

View file

@ -21,6 +21,9 @@
#include <alloca.h> #include <alloca.h>
#endif #endif
#define GGML_VERSION "0.0.1"
#define GGML_COMMIT "KCPP"
#include <assert.h> #include <assert.h>
#include <errno.h> #include <errno.h>
#include <time.h> #include <time.h>
@ -474,6 +477,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
} }
const char * ggml_version(void) {
return GGML_VERSION;
}
const char * ggml_commit(void) {
return GGML_COMMIT;
}
// //
// timing // timing
// //
@ -1145,9 +1156,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU", "REGLU",
"GEGLU", "GEGLU",
"SWIGLU", "SWIGLU",
"GEGLU_ERF",
"GEGLU_QUICK",
}; };
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3"); static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@ -2773,6 +2786,48 @@ struct ggml_tensor * ggml_swiglu_split(
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
} }
// ggml_geglu_erf
struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
}
struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
}
struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
}
// ggml_geglu_quick
struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
}
struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
}
struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
}
// ggml_norm // ggml_norm
static struct ggml_tensor * ggml_norm_impl( static struct ggml_tensor * ggml_norm_impl(
@ -3679,9 +3734,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) { if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]); GGML_ASSERT(mask->ne[1] >= a->ne[1]);
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
} }
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
@ -4702,13 +4758,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
GGML_ASSERT(ggml_can_mul_mat(k, q)); GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT) // TODO: check if vT can be multiplied by (k*qT)
GGML_ASSERT(q->ne[3] == k->ne[3]);
GGML_ASSERT(q->ne[3] == v->ne[3]);
if (mask) { if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
} }
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
@ -4836,7 +4896,6 @@ struct ggml_tensor * ggml_ssm_conv(
const int64_t n_s = sx->ne[2]; const int64_t n_s = sx->ne[2];
// TODO: maybe support other strides than 1? // TODO: maybe support other strides than 1?
// FIXME: this is always true?
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(n_t >= 0); GGML_ASSERT(n_t >= 0);
@ -4859,36 +4918,49 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt, struct ggml_tensor * dt,
struct ggml_tensor * A, struct ggml_tensor * A,
struct ggml_tensor * B, struct ggml_tensor * B,
struct ggml_tensor * C) { struct ggml_tensor * C,
struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A)); GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(ggml_is_matrix(A)); GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt)); GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C)); GGML_ASSERT(ggml_are_same_shape(B, C));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
{ {
const int64_t d_state = s->ne[0]; const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1]; const int64_t head_dim = x->ne[0];
const int64_t n_seq_tokens = x->ne[1]; const int64_t n_head = x->ne[1];
const int64_t n_seqs = x->ne[2]; const int64_t n_seq_tokens = x->ne[2];
const int64_t n_seqs = x->ne[3];
GGML_ASSERT(s->ne[2] == n_seqs); GGML_ASSERT(dt->ne[0] == n_head);
GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(dt->ne[1] == n_seq_tokens);
GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(dt->ne[2] == n_seqs);
GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_seq_tokens); GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[2] == n_seqs); GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(ids->ne[0] == n_seqs);
GGML_ASSERT(ggml_is_vector(ids));
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
if (A->ne[0] != 1) {
// Mamba-1 has more granular decay factors
GGML_ASSERT(A->ne[0] == d_state);
}
} }
// concatenated y + ssm_states // concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
result->op = GGML_OP_SSM_SCAN; result->op = GGML_OP_SSM_SCAN;
result->src[0] = s; result->src[0] = s;
@ -4897,6 +4969,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A; result->src[3] = A;
result->src[4] = B; result->src[4] = B;
result->src[5] = C; result->src[5] = C;
result->src[6] = ids;
return result; return result;
} }
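The reworked asserts encode the tensor layout ggml_ssm_scan now expects (multi-head, Mamba-2 style, with recurrent-state slots selected by ids). As a reading aid, the shapes in ggml's ne order (fastest dimension first), printed from symbolic sizes; the n_group dimension of B/C is not asserted above and is listed here only as an assumption about the Mamba-2 layout:

d_state, head_dim, n_head, n_seq_tokens, n_seqs, n_group = 128, 64, 24, 32, 2, 1

shapes = {
    "s (states)": (d_state, head_dim, n_head, "n_slots"),   # ids picks n_seqs slots out of the pool
    "x":          (head_dim, n_head, n_seq_tokens, n_seqs),
    "dt":         (n_head, n_seq_tokens, n_seqs),
    "A":          (1, n_head),                              # or (d_state, n_head) for Mamba-1-style decay
    "B and C":    (d_state, n_group, n_seq_tokens, n_seqs),  # n_group assumed, not asserted
    "ids":        (n_seqs,),
}
for name, ne in shapes.items():
    print(f"{name:11s} {ne}")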
@ -6037,13 +6110,28 @@ static void ggml_compute_backward(
} }
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break; } break;
case GGML_OP_GLU: {
switch (ggml_get_glu_op(tensor)) {
case GGML_GLU_OP_SWIGLU: {
if (src0_needs_grads) {
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
}
} break;
default: {
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
} //break;
}
} break;
case GGML_OP_NONE: { case GGML_OP_NONE: {
// noop // noop
} break; } break;
case GGML_OP_COUNT: case GGML_OP_COUNT:
default: { default: {
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
GGML_ABORT("fatal error");
} //break; } //break;
} }

View file

@ -170,6 +170,7 @@ class Keys:
INNER_SIZE = "{arch}.ssm.inner_size" INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size" STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank" TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class WKV: class WKV:
@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum):
RWKV7 = auto() RWKV7 = auto()
ARWKV7 = auto() ARWKV7 = auto()
MAMBA = auto() MAMBA = auto()
MAMBA2 = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
COHERE2 = auto() COHERE2 = auto()
@ -429,6 +431,7 @@ class MODEL_TENSOR(IntEnum):
SSM_DT = auto() SSM_DT = auto()
SSM_A = auto() SSM_A = auto()
SSM_D = auto() SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto() SSM_OUT = auto()
TIME_MIX_W0 = auto() TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto() TIME_MIX_W1 = auto()
@ -628,6 +631,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.MAMBA2: "mamba2",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2", MODEL_ARCH.COHERE2: "cohere2",
@ -730,6 +734,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@ -1714,6 +1719,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT, MODEL_TENSOR.SSM_OUT,
], ],
MODEL_ARCH.MAMBA2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [ MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
@ -2497,6 +2515,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization # tokenization

View file

@ -714,8 +714,8 @@ class GGUFWriter:
def add_clamp_kqv(self, value: float) -> None: def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
-    def add_shared_kv_layers(self, value: float) -> None:
-        self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+    def add_shared_kv_layers(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
@ -861,6 +861,9 @@ class GGUFWriter:
def add_ssm_time_step_rank(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
def add_ssm_group_count(self, value: int) -> None:
self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
def add_ssm_dt_b_c_rms(self, value: bool) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
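For a quick sanity check of the new metadata key (not part of the diff): the Python below assumes a gguf-py build that already contains this change, and simply prints the key name that add_ssm_group_count() will write for a Mamba-2 model.

from gguf.constants import Keys

# GGUFWriter.add_ssm_group_count() formats the arch into this template
print(Keys.SSM.GROUP_COUNT.format(arch="mamba2"))  # -> mamba2.ssm.group_count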

View file

@ -477,7 +477,7 @@ class TensorNameMap:
"encoder.layers.{bid}.norm2", # nomic-bert "encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
), ),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
@ -574,6 +574,10 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.D", "backbone.layers.{bid}.mixer.D",
), ),
MODEL_TENSOR.SSM_NORM: (
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: ( MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", "model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj",

View file

@ -245,9 +245,18 @@ class SpecialVocab:
         if not tokenizer_config:
             return True
         chat_template_alt = None
-        chat_template_file = path / 'chat_template.json'
-        if chat_template_file.is_file():
-            with open(chat_template_file, encoding = 'utf-8') as f:
+        chat_template_json = path / 'chat_template.json'
+        chat_template_jinja = path / 'chat_template.jinja'
+        if chat_template_jinja.is_file():
+            with open(chat_template_jinja, encoding = 'utf-8') as f:
+                chat_template_alt = f.read()
+            if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
+                chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
+                for template_path in additional_templates:
+                    with open(template_path, encoding = 'utf-8') as fp:
+                        chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
+        elif chat_template_json.is_file():
+            with open(chat_template_json, encoding = 'utf-8') as f:
                 chat_template_alt = json.load(f).get('chat_template')
         chat_template = tokenizer_config.get('chat_template', chat_template_alt)
         if chat_template is None or isinstance(chat_template, (str, list)):
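As a standalone illustration of the new lookup order (a sketch, not the SpecialVocab code itself): prefer chat_template.jinja, fold any additional_chat_templates/*.jinja in as named templates, and only then fall back to the legacy chat_template.json.

import json
from pathlib import Path

def load_chat_template(path: Path):
    jinja  = path / 'chat_template.jinja'
    legacy = path / 'chat_template.json'
    if jinja.is_file():
        template = jinja.read_text(encoding='utf-8')
        extra = sorted((path / 'additional_chat_templates').glob('*.jinja'))
        if not extra:
            return template
        templates = [{'name': 'default', 'template': template}]
        for p in extra:
            templates.append({'name': p.stem, 'template': p.read_text(encoding='utf-8')})
        return templates
    if legacy.is_file():
        return json.loads(legacy.read_text(encoding='utf-8')).get('chat_template')
    return None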

View file

@ -3208,6 +3208,7 @@ Current version indicated by LITEVER below.
var schedule_multiplayer_major_change = false; var schedule_multiplayer_major_change = false;
var last_request_str = "No Requests Available"; //full context of last submitted request var last_request_str = "No Requests Available"; //full context of last submitted request
var last_response_obj = null; var last_response_obj = null;
var last_response_streamlog = "";
var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp
var kai_poll_recoverykey = ""; //for recovering a lost polled streaming in case of disconnect. var kai_poll_recoverykey = ""; //for recovering a lost polled streaming in case of disconnect.
var globalabortcontroller = null; var globalabortcontroller = null;
@ -5634,6 +5635,10 @@ Current version indicated by LITEVER below.
}, },
transform(chunk, ctrl) { transform(chunk, ctrl) {
ctrl.buf += chunk; ctrl.buf += chunk;
if(chunk)
{
last_response_streamlog += chunk;
}
let evs = []; let evs = [];
let m; let m;
while ((m = /^data: ?(.*)(\r?\n){2}/m.exec(ctrl.buf)) !== null) { while ((m = /^data: ?(.*)(\r?\n){2}/m.exec(ctrl.buf)) !== null) {
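The stream log is fed by the same server-sent-event framing this transform already parses. Purely as a conceptual aid, a Python sketch of the same "buffer until a complete data: frame, then emit it" loop (names are invented, this is not Lite code):

import re

_SSE_FRAME = re.compile(r'^data: ?(.*)(\r?\n){2}', re.M)

def drain_sse(buf: str):
    # split complete 'data: ...' frames out of buf; return (events, leftover)
    events = []
    while (m := _SSE_FRAME.search(buf)) is not None:
        events.append(m.group(1))
        buf = buf[:m.start()] + buf[m.end():]
    return events, buf

# two complete events are emitted, the partial frame stays buffered for the next chunk
print(drain_sse('data: {"a":1}\n\ndata: {"b":2}\n\ndata: {"c'))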
@ -9343,6 +9348,10 @@ Current version indicated by LITEVER below.
{ {
lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj); lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj);
} }
if(last_response_streamlog)
{
lr += "\n\nResponse:\n" + last_response_streamlog;
}
msgbox(lr,"Last Request Info",false); msgbox(lr,"Last Request Info",false);
} }
function show_last_logprobs() function show_last_logprobs()
@ -10401,7 +10410,7 @@ Current version indicated by LITEVER below.
desired_oai_ep = transform_oai_ep(desired_oai_ep); desired_oai_ep = transform_oai_ep(desired_oai_ep);
let oaiheaders = {}; let oaiheaders = {};
-            if(desired_oai_key!=""){
+            if(desired_oai_key!="" && !desired_oai_ep.toLowerCase().includes("pollinations.ai")){
oaiheaders["Authorization"] = "Bearer " + desired_oai_key; oaiheaders["Authorization"] = "Bearer " + desired_oai_key;
}; };
if (desired_oai_ep.toLowerCase().includes("api.mistral.ai")) { if (desired_oai_ep.toLowerCase().includes("api.mistral.ai")) {
@ -14041,6 +14050,7 @@ Current version indicated by LITEVER below.
redo_arr = []; redo_arr = [];
last_request_str = "No Requests Available"; last_request_str = "No Requests Available";
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
retry_prev_text = []; retry_prev_text = [];
retry_preserve_last = false; retry_preserve_last = false;
retry_in_progress = false; retry_in_progress = false;
@ -16471,6 +16481,7 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(submit_payload); last_request_str = JSON.stringify(submit_payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) { if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) {
let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint); let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint);
kobold_api_stream_sse(sub_endpt, submit_payload); kobold_api_stream_sse(sub_endpt, submit_payload);
@ -16643,11 +16654,15 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(oai_payload); last_request_str = JSON.stringify(oai_payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
             let oaiheaders = {
-                'Content-Type': 'application/json',
-                'Authorization': 'Bearer ' + custom_oai_key
+                'Content-Type': 'application/json'
             };
+            if (!targetep.toLowerCase().includes("pollinations.ai")) {
+                oaiheaders['Authorization'] = 'Bearer ' + custom_oai_key;
+            }
if(targetep.toLowerCase().includes("openrouter.ai")) if(targetep.toLowerCase().includes("openrouter.ai"))
{ {
oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net"; oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net";
@ -16780,6 +16795,7 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(claude_payload); last_request_str = JSON.stringify(claude_payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
let claudeheaders = { let claudeheaders = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@ -16970,6 +16986,7 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(payload); last_request_str = JSON.stringify(payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
let geminiheaders = { 'Content-Type': 'application/json' }; let geminiheaders = { 'Content-Type': 'application/json' };
if(is_browser_supports_sse() && localsettings.tokenstreammode!=0) if(is_browser_supports_sse() && localsettings.tokenstreammode!=0)
@ -17018,6 +17035,7 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(cohere_payload); last_request_str = JSON.stringify(cohere_payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
let cohere_headers = { let cohere_headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': 'Bearer ' + custom_cohere_key 'Authorization': 'Bearer ' + custom_cohere_key
@ -17093,6 +17111,7 @@ Current version indicated by LITEVER below.
last_request_str = JSON.stringify(submit_payload); last_request_str = JSON.stringify(submit_payload);
last_response_obj = null; last_response_obj = null;
last_response_streamlog = "";
fetch(horde_submit_endpoint, { fetch(horde_submit_endpoint, {
method: 'POST', // or 'PUT' method: 'POST', // or 'PUT'

View file

@ -62,7 +62,7 @@ dry_seq_break_max = 128
extra_images_max = 4 extra_images_max = 4
# global vars # global vars
KcppVersion = "1.95.1" KcppVersion = "1.96"
showdebug = True showdebug = True
kcpp_instance = None #global running instance kcpp_instance = None #global running instance
global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""} global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}

View file

@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" }, { LLM_ARCH_COHERE2, "cohere2" },
@ -170,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@ -1004,6 +1006,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
}, },
}, },
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{ {
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
{ {
@ -1761,6 +1779,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@ -1894,6 +1913,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
bool llm_arch_is_recurrent(const llm_arch & arch) { bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) { switch (arch) {
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7: case LLM_ARCH_RWKV7:

View file

@ -49,6 +49,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R, LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2, LLM_ARCH_COHERE2,
@ -174,6 +175,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS, LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE, LLM_KV_WKV_HEAD_SIZE,
@ -293,6 +295,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W1,

View file

@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
// note: tracking the other way around is not necessary for now // note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true; //seq_cpl[s0][s1] = true;
has_cpl = true;
} }
} }
} }
@ -404,6 +406,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs; return n_outputs;
} }
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() { std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids; return out_ids;
} }
@ -419,6 +425,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() { void llama_batch_allocr::split_reset() {
out_ids.clear(); out_ids.clear();
n_used = 0;
used.clear(); used.clear();
used.resize(get_n_tokens(), false); used.resize(get_n_tokens(), false);
@ -443,6 +451,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx); idxs.push_back(cur_idx);
used[cur_idx] = true; used[cur_idx] = true;
++n_used;
++cur_idx; ++cur_idx;
@ -458,9 +467,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false); return ubatch_add(idxs, idxs.size(), false);
} }
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
return {};
}
std::vector<seq_set_t> cur_seq_set; std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch // determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) { for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) { if (used[i]) {
@ -477,9 +494,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
} }
} }
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) { if (add) {
cur_seq_set.push_back(seq_set[i]); cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) { if (cur_seq_set.size() > n_ubatch) {
break; break;
} }
@ -528,6 +552,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx); idxs_per_seq[s].push_back(idx);
used[idx] = true; used[idx] = true;
++n_used;
++cur_idx[s]; ++cur_idx[s];
} }
@ -569,6 +594,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx); idxs.push_back(cur_idx);
used[cur_idx] = true; used[cur_idx] = true;
++n_used;
if (idxs.size() >= n_ubatch) { if (idxs.size() >= n_ubatch) {
break; break;
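To make the new sequential flag concrete: a heavily simplified Python model (not the allocator) of which sequence ids split_equal() will accept when sequential == true — the first one unconditionally, then only ids that extend the previous one by exactly 1, so the resulting sequence sets come out in increasing seq-id order.

def pick_sequential_seqs(token_seq_ids):
    # returns the seq ids accepted under the 'sequential' rule, in order
    accepted, last = [], -1
    for sid in token_seq_ids:
        if sid in accepted:
            continue                       # token joins an already-accepted sequence set
        if not accepted or sid == last + 1:
            accepted.append(sid)           # accept only increasing sequence ids
            last = sid
    return accepted

print(pick_sequential_seqs([4, 4, 5, 7, 6, 6]))  # -> [4, 5, 6]; seq 7 is left for a later ubatch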

View file

@ -54,6 +54,7 @@ public:
uint32_t get_n_tokens() const; uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const; uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting // the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids(); std::vector<int32_t> & get_out_ids();
@ -69,7 +70,8 @@ public:
llama_ubatch split_simple(uint32_t n_ubatch); llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// sequence-set-wise split - each ubatch contains a single sequence-set // sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch); llama_ubatch split_seq(uint32_t n_ubatch);
@ -112,6 +114,9 @@ private:
using pos_set_t = std::set<llama_pos>; using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>; using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
@ -125,6 +130,8 @@ private:
// batch indices of the output // batch indices of the output
std::vector<int32_t> out_ids; std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch // used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used; std::vector<bool> used;

View file

@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
} }
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
} }
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
} }
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) { mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs(); const int64_t n_rs = mctx->get_recr()->get_n_rs();
@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
} }
} }
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
GGML_ASSERT(one && ggml_nelements(one) == 1); GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f; float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto n_kv = inp->mctx->get_attn()->get_n_kv(); const auto n_kv = inp->mctx->get_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
//cb(inp->self_kq_mask, "KQ_mask", -1); inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
//cb(inp_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->kq_mask); ggml_set_input(inp->kq_mask);
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
const auto n_kv = mctx_cur->get_n_kv(); const auto n_kv = mctx_cur->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
//cb(inp->self_kq_mask, "KQ_mask", -1); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); const auto & k_idxs = inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
} }
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn(
// optionally store to KV cache // optionally store to KV cache
if (k_cur) { if (k_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
} }
if (v_cur) { if (v_cur) {
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
} }
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->cross_kq_mask); ggml_set_input(inp->cross_kq_mask);
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache // store to KV cache
{ {
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); const auto & k_idxs = inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); const auto & v_idxs = inp->get_v_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
} }
const auto & kq_mask = inp->get_kq_mask(); const auto & kq_mask = inp->get_kq_mask();
@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
{ {
const auto n_kv = mctx_cur->get_base()->get_n_kv(); const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
//cb(inp->self_kq_mask, "KQ_mask", -1); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask); ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
const auto n_kv = mctx_cur->get_swa()->get_n_kv(); const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
ggml_set_input(inp->self_kq_mask_swa); ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs(
uint32_t kv_head, uint32_t kv_head,
uint32_t kv_size, uint32_t kv_size,
int32_t rs_zero, int32_t rs_zero,
bool avoid_copies) const { const llm_graph_get_rows_fn & get_state_rows) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size); ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
ggml_tensor * output_states;
if (!avoid_copies) {
// copy states // copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs} // {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states); ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv) // copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies) const { const llm_graph_get_rows_fn & get_state_rows) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx); const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
} }
ggml_tensor * llm_graph_context::build_rs( ggml_tensor * llm_graph_context::build_rs(
@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies) const { const llm_graph_get_rows_fn & get_state_rows) const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr(); const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
} }
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
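The refactor above swaps the old avoid_copies bool for an injectable row-gather callback (llm_graph_get_rows_fn, defaulting to ggml_get_rows), so a caller such as a recurrent-state graph can control how the state rows are gathered. A minimal Python sketch of that callback pattern, purely to show the shape of the API change (no ggml involved):

def default_get_rows(states, ids):
    # stand-in for ggml_get_rows(): plain gather of the requested state rows
    return [states[i] for i in ids]

def build_rs(states, ids, get_state_rows=default_get_rows):
    # a caller can pass its own gather so the read is ordered/fused the way its op needs
    return get_state_rows(states, ids)

states = [[0.0], [1.0], [2.0], [3.0]]
print(build_rs(states, [2, 0]))                                     # default gather
print(build_rs(states, [2, 0], lambda s, i: [s[j][0] for j in i]))  # custom gather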

View file

@ -228,8 +228,8 @@ public:
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
@ -249,10 +249,16 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
@ -274,13 +280,23 @@ public:
void set_input(const llama_ubatch * ubatch) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
@ -297,8 +313,8 @@ public:
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
const llama_cross * cross = nullptr; const llama_cross * cross = nullptr;
}; };
@ -319,10 +335,16 @@ public:
ggml_tensor * s_copy; // I32 [kv_size] ggml_tensor * s_copy; // I32 [kv_size]
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
const llama_hparams & hparams; const llama_hparams & hparams;
const llama_cparams & cparams; const llama_cparams & cparams;
@ -336,7 +358,7 @@ public:
llm_graph_input_one() {} llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default; virtual ~llm_graph_input_one() = default;
void set_input(const llama_ubatch *) override; void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * one = nullptr; // F32 ggml_tensor * one = nullptr; // F32
}; };
@ -424,6 +446,9 @@ struct llm_graph_params {
const llm_graph_cb & cb; const llm_graph_cb & cb;
}; };
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
struct llm_graph_context { struct llm_graph_context {
const llm_arch arch; const llm_arch arch;
@ -663,7 +688,7 @@ struct llm_graph_context {
uint32_t kv_head, uint32_t kv_head,
uint32_t kv_size, uint32_t kv_size,
int32_t rs_zero, int32_t rs_zero,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
llm_graph_input_rs * build_rs_inp() const; llm_graph_input_rs * build_rs_inp() const;
@ -673,7 +698,7 @@ struct llm_graph_context {
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rs( ggml_tensor * build_rs(
llm_graph_input_mem_hybrid * inp, llm_graph_input_mem_hybrid * inp,
@ -681,7 +706,7 @@ struct llm_graph_context {
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rwkv_token_shift_load( ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp, llm_graph_input_rs * inp,

View file

@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {
     // TODO: maybe support other convolution strides than 1
     // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+    // Corresponds to Mamba's conv_states size
+    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }
 
 uint32_t llama_hparams::n_embd_s() const {
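A worked example of the updated per-layer conv-state size (a sketch with made-up, Mamba-2-style dimensions; the extra 2*n_group*d_state term covers the grouped B/C channels that now share the conv buffer):

# hypothetical dims, not taken from any real config
d_conv, d_inner, d_state, n_group = 4, 4096, 128, 8

n_embd_r = (d_conv - 1) * (d_inner + 2 * n_group * d_state)
print(n_embd_r)  # (4 - 1) * (4096 + 2 * 8 * 128) = 3 * 6144 = 18432 elements per layer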

View file

@ -114,6 +114,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0; uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0; uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0; uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state space models // for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

View file

@ -117,20 +117,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); if (balloc.get_n_used() < balloc.get_n_tokens()) {
if (heads_base.empty()) { // failed to find a suitable split
break; break;
} }
auto heads_swa = kv_swa->prepare(ubatches); auto sinfos_base = kv_base->prepare(ubatches);
if (heads_swa.empty()) { if (sinfos_base.empty()) {
break; break;
} }
assert(heads_base.size() == heads_swa.size()); auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
// if it fails, try equal split // if it fails, try equal split
@ -139,7 +144,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) { while (true) {
auto ubatch = balloc.split_equal(n_ubatch); auto ubatch = balloc.split_equal(n_ubatch, false);
if (ubatch.n_tokens == 0) { if (ubatch.n_tokens == 0) {
break; break;
@ -148,20 +153,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads_base = kv_base->prepare(ubatches); if (balloc.get_n_used() < balloc.get_n_tokens()) {
if (heads_base.empty()) { // failed to find a suitable split
break; break;
} }
auto heads_swa = kv_swa->prepare(ubatches); auto sinfos_base = kv_base->prepare(ubatches);
if (heads_swa.empty()) { if (sinfos_base.empty()) {
break; break;
} }
assert(heads_base.size() == heads_swa.size()); auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>( return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false); } while (false);
// TODO: if we fail again, we should attempt different splitting strategies // TODO: if we fail again, we should attempt different splitting strategies
@ -224,13 +234,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base, slot_info_vec_t sinfos_base,
std::vector<uint32_t> heads_swa, slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
} }

View file

@ -74,6 +74,8 @@ private:
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public: public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// used for errors // used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status); llama_kv_cache_unified_iswa_context(llama_memory_status status);
@ -90,8 +92,8 @@ public:
// used to create a batch processing context from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
std::vector<uint32_t> heads_base, slot_info_vec_t sinfos_base,
std::vector<uint32_t> heads_swa, slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context(); virtual ~llama_kv_cache_unified_iswa_context();

View file

@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
if (!supports_set_rows) {
LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
}
} }
void llama_kv_cache_unified::clear(bool data) { void llama_kv_cache_unified::clear(bool data) {
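The new write path stays opt-in for now: it is only taken when the LLAMA_SET_ROWS environment variable is set to a non-zero value before the process starts; otherwise the cache keeps the old ggml_cpy() behaviour. A trivial Python rendering of the same gate, just to show the intended toggle:

import os

supports_set_rows = int(os.environ.get("LLAMA_SET_ROWS", "0") or "0")
if not supports_set_rows:
    print("LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility")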
@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
ubatches.push_back(std::move(ubatch)); // NOLINT ubatches.push_back(std::move(ubatch)); // NOLINT
} }
auto heads = prepare(ubatches); if (balloc.get_n_used() < balloc.get_n_tokens()) {
if (heads.empty()) { // failed to find a suitable split
break;
}
auto sinfos = prepare(ubatches);
if (sinfos.empty()) {
break; break;
} }
return std::make_unique<llama_kv_cache_unified_context>( return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches)); this, std::move(sinfos), std::move(ubatches));
} while (false); } while (false);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo)); return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
} }
-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::slot_info_vec_t res;
struct state { struct state {
uint32_t head_old; // old position of the head, before placing the ubatch uint32_t head_old; // old position of the head, before placing the ubatch
uint32_t head_new; // new position of the head, after placing the ubatch
slot_info sinfo; // slot info for the ubatch
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
}; };
@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
bool success = true; bool success = true;
for (const auto & ubatch : ubatches) { for (const auto & ubatch : ubatches) {
// non-continuous slots require support for ggml_set_rows()
const bool cont = supports_set_rows ? false : true;
// only find a suitable slot for the ubatch. don't modify the cells yet // only find a suitable slot for the ubatch. don't modify the cells yet
const int32_t head_new = find_slot(ubatch); const auto sinfo_new = find_slot(ubatch, cont);
if (head_new < 0) { if (sinfo_new.empty()) {
success = false; success = false;
break; break;
} }
// remember the position that we found
res.push_back(head_new); res.push_back(sinfo_new);
// store the old state of the cells in the recovery stack // store the old state of the cells in the recovery stack
states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
// now emplace the ubatch // now emplace the ubatch
apply_ubatch(head_new, ubatch); apply_ubatch(sinfo_new, ubatch);
} }
// iterate backwards and restore the cells to their original state // iterate backwards and restore the cells to their original state
for (auto it = states.rbegin(); it != states.rend(); ++it) { for (auto it = states.rbegin(); it != states.rend(); ++it) {
cells.set(it->head_new, it->cells); cells.set(it->sinfo.idxs, it->cells);
head = it->head_old; head = it->head_old;
} }
@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
return updated; return updated;
} }
-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head; uint32_t head_cur = this->head;
@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
if (n_tokens > cells.size()) { if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return -1; return { };
} }
if (debug > 0) { if (debug > 0) {
@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
uint32_t n_tested = 0; uint32_t n_tested = 0;
// for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
// for non-continuous slots, we test the tokens one by one
const uint32_t n_test = cont ? n_tokens : 1;
slot_info res;
auto & idxs = res.idxs;
idxs.reserve(n_tokens);
while (true) { while (true) {
if (head_cur + n_tokens > cells.size()) { if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur; n_tested += cells.size() - head_cur;
head_cur = 0; head_cur = 0;
continue; continue;
} }
bool found = true; for (uint32_t i = 0; i < n_test; i++) {
for (uint32_t i = 0; i < n_tokens; i++) { const auto idx = head_cur;
//const llama_pos pos = ubatch.pos[i]; //const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0]; //const llama_seq_id seq_id = ubatch.seq_id[i][0];
@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
// - (disabled) mask causally, if the sequence is the same as the one we are inserting // - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache // - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos // always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i); bool can_use = cells.is_empty(idx);
if (!can_use && cells.seq_count(head_cur + i) == 1) { if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(head_cur + i); const llama_pos pos_cell = cells.pos_get(idx);
// (disabled) causal mask // (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand // note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(head_cur + i, seq_id)) { //if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos; // can_use = pos_cell >= pos;
//} //}
if (!can_use) { if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask // SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
} }
} }
if (!can_use) { head_cur++;
found = false; n_tested++;
head_cur += i + 1;
n_tested += i + 1; if (can_use) {
idxs.push_back(idx);
} else {
break; break;
} }
} }
if (found) { if (idxs.size() == n_tokens) {
break; break;
} }
if (cont) {
idxs.clear();
}
if (n_tested >= cells.size()) { if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return -1; return { };
} }
} }
return head_cur; // we didn't find a suitable slot - return empty result
if (idxs.size() < n_tokens) {
res.clear();
}
return res;
} }
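A compact Python model of the new slot search (an illustration, not the llama.cpp implementation, and it ignores the per-cell sequence/SWA checks): with cont=True it behaves like the old code and demands one contiguous free run of n_tokens cells; with cont=False it collects any n_tokens usable cells, wrapping around the ring buffer, and returns their indices.

def find_slot(free, head, n_tokens, cont):
    # free: list[bool] per cache cell; returns the chosen cell indices or [] on failure
    size = len(free)
    idxs, n_tested = [], 0
    n_test = n_tokens if cont else 1
    while True:
        if head + n_test > size:        # wrap around instead of running off the end
            n_tested += size - head
            head = 0
            continue
        for _ in range(n_test):
            idx = head
            head += 1
            n_tested += 1
            if free[idx]:
                idxs.append(idx)
            else:
                break
        if len(idxs) == n_tokens:
            return idxs
        if cont:
            idxs.clear()                # contiguous run broken, start over
        if n_tested >= size:
            return []                   # cache too full or too fragmented

cells = [True, False, True, True, False, True]
print(find_slot(cells, head=0, n_tokens=3, cont=False))  # [0, 2, 3]
print(find_slot(cells, head=0, n_tokens=3, cont=True))   # []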
-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// keep track of the max sequence position that we would overwrite with this ubatch // keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty // for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_pos_max_rm[s] = -1; seq_pos_max_rm[s] = -1;
} }
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { assert(ubatch.n_tokens == sinfo.idxs.size());
if (!cells.is_empty(head_cur + i)) {
assert(cells.seq_count(head_cur + i) == 1);
const llama_seq_id seq_id = cells.seq_get(head_cur + i); for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const llama_pos pos = cells.pos_get(head_cur + i); const auto idx = sinfo.idxs.at(i);
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + i); cells.rm(idx);
} }
cells.pos_set(head_cur + i, ubatch.pos[i]); cells.pos_set(idx, ubatch.pos[i]);
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]); cells.seq_add(idx, ubatch.seq_id[i][s]);
} }
} }
@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
} }
// move the head at the end of the slot // move the head at the end of the slot
head = head_cur + ubatch.n_tokens; head = sinfo.idxs.back() + 1;
} }
bool llama_kv_cache_unified::get_can_shift() const { bool llama_kv_cache_unified::get_can_shift() const {
@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
0); 0);
} }
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il); const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k; auto * k = layers[ikv].k;
const int64_t n_embd_k_gqa = k->ne[0];
const int64_t n_tokens = k_cur->ne[2]; const int64_t n_tokens = k_cur->ne[2];
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
if (k_idxs && supports_set_rows) {
return ggml_set_rows(ctx, k, k_cur, k_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
ggml_tensor * k_view = ggml_view_1d(ctx, k, ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*hparams.n_embd_k_gqa(il), n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
return ggml_cpy(ctx, k_cur, k_view); return ggml_cpy(ctx, k_cur, k_view);
} }
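What ggml_set_rows(k, k_cur, k_idxs) buys here, in plain-Python terms (a toy model of the row-scatter semantics, not ggml): each new K row is written to an arbitrary cache cell named by k_idxs, so the destination no longer has to be the single contiguous block that the ggml_cpy() fallback still requires.

def set_rows(dst, src, idxs):
    # toy stand-in for ggml_set_rows(): dst[idxs[i]] = src[i]
    for i, row in zip(idxs, src):
        dst[i] = list(row)
    return dst

k_cache = [[0.0] * 4 for _ in range(8)]       # 8 cells, toy n_embd_k_gqa = 4
k_cur   = [[1.0] * 4, [2.0] * 4, [3.0] * 4]   # 3 new tokens
set_rows(k_cache, k_cur, idxs=[2, 5, 6])      # non-contiguous slot {2, 5, 6}
print([row[0] for row in k_cache])            # [0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0]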
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il); const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v; auto * v = layers[ikv].v;
const int64_t n_embd_v_gqa = v->ne[0];
const int64_t n_tokens = v_cur->ne[2]; const int64_t n_tokens = v_cur->ne[2];
v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
if (v_idxs && supports_set_rows) {
if (!v_trans) {
return ggml_set_rows(ctx, v, v_cur, v_idxs);
}
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
// note: the V cache is transposed when not using flash attention
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
// note: we can be more explicit here at the cost of extra cont
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
//v_cur = ggml_transpose(ctx, v_cur);
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
// we broadcast the KV indices n_embd_v_gqa times
// v [1, n_kv, n_embd_v_gqa]
// v_cur [1, n_tokens, n_embd_v_gqa]
// v_idxs [n_tokens, 1, 1]
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends
ggml_tensor * v_view = nullptr;

if (!v_trans) {
v_view = ggml_view_1d(ctx, v,
n_tokens*n_embd_v_gqa,
ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
} else {
v_cur = ggml_transpose(ctx, v_cur);

v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
(v->ne[1]    )*ggml_element_size(v),
(sinfo.head())*ggml_element_size(v));
}

return ggml_cpy(ctx, v_cur, v_view);
}
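For reference, a minimal scalar sketch of the scatter that the ggml_set_rows() paths above perform (the helper below is illustrative, not llama.cpp or ggml API): row i of the per-token K/V data is written to row idx[i] of the cache, which is exactly how sinfo.idxs maps ubatch tokens onto possibly non-contiguous KV cells.

#include <cstdint>
#include <vector>

// illustrative reference: dst holds n_cells rows of n_embd floats, src holds idx.size() rows
static void set_rows_ref(std::vector<float> & dst, int64_t n_embd,
                         const std::vector<float> & src,
                         const std::vector<int64_t> & idx) {
    for (size_t i = 0; i < idx.size(); ++i) {
        for (int64_t c = 0; c < n_embd; ++c) {
            dst[idx[i]*n_embd + c] = src[i*n_embd + c];
        }
    }
}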
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
ggml_set_input(k_idxs);
return k_idxs;
}
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
ggml_set_input(v_idxs);
return v_idxs;
}
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
}
void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs.at(i);
}
}
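As a hedged wiring sketch of how these pieces fit together (only member functions shown in this change are used; the free function and its name are illustrative): the index tensor is created at graph-build time, cpy_k() turns it into a ggml_set_rows() node, and set_input_k_idxs() later fills it with sinfo.idxs before the graph is computed.

// illustrative sketch, assuming the interfaces declared in this change
static ggml_tensor * store_k_for_ubatch(
        ggml_context * ctx,
        const llama_kv_cache_unified & kv,
        const llama_kv_cache_unified::slot_info & sinfo,
        const llama_ubatch & ubatch,
        ggml_tensor * k_cur,
        int32_t il) {
    ggml_tensor * k_idxs = kv.build_input_k_idxs(ctx, ubatch);      // I64 [n_tokens], marked as graph input
    ggml_tensor * k_cpy  = kv.cpy_k(ctx, k_cur, k_idxs, il, sinfo); // ggml_set_rows into the K cache
    // after the graph is allocated: kv.set_input_k_idxs(k_idxs, &ubatch, sinfo);
    return k_cpy;
}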
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens;
@@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
ubatch.seq_id[i] = &dest_seq_id;
}

const auto sinfo = find_slot(ubatch, true);
if (sinfo.empty()) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}

apply_ubatch(sinfo, ubatch);

const auto head_cur = sinfo.head();

// keep the head at the old position because we will read the KV data into it in state_read_data()
head = head_cur;
@@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();

// create a dummy slot info - the actual data is irrelevant. we just need to build the graph
sinfos.resize(1);
sinfos[0].idxs.resize(1);
sinfos[0].idxs[0] = 0;
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_unified::slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

@@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

if (++i_cur >= ubatches.size()) {
return false;
}

@@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {
return true;
}

kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);

n_kv = kv->get_n_kv();

return true;
}

@@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

return ubatches[i_cur];
}

uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
return kv->get_v(ctx, il, n_kv);
}

ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}

ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
@@ -24,8 +24,6 @@ public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;

struct defrag_info {
bool empty() const {
return ids.empty();
@@ -37,6 +35,32 @@ public:
std::vector<uint32_t> ids;
};
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
idx_vec_t idxs;
uint32_t head() const {
return idxs.at(0);
}
bool empty() const {
return idxs.empty();
}
void clear() {
idxs.clear();
}
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
};
using slot_info_vec_t = std::vector<slot_info>;
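For illustration, a small hypothetical slot_info value showing the token-to-cell mapping described above; the indices need not be contiguous, and head() is only used by the ggml_cpy() fallback, which assumes a contiguous slot.

llama_kv_cache_unified::slot_info sinfo_example;
sinfo_example.idxs = {7, 8, 12};   // token 0 -> cell 7, token 1 -> cell 8, token 2 -> cell 12
// sinfo_example.head() == 7, sinfo_example.empty() == false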
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
@@ -102,30 +126,37 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

//
// preparation API
//

// find places for the provided ubatches in the cache, returns the slot infos
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
// return empty slot_info on failure
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

//
// input API
//

ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift   (ggml_tensor * dst) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
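A hedged usage sketch of the flow these declarations imply (assuming a non-const llama_kv_cache_unified kv and a prepared llama_ubatch ubatch; the real callers live in prepare() and in the unified-context apply() path):

llama_kv_cache_unified::slot_info sinfo = kv.find_slot(ubatch, /*cont =*/ false);
if (!sinfo.empty()) {
    kv.apply_ubatch(sinfo, ubatch);   // write pos/seq_id of the ubatch tokens into the chosen cells
}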
@@ -157,8 +188,13 @@ private:
// SWA
const uint32_t n_swa = 0;

// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;

// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
int supports_set_rows = false;

const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

std::vector<ggml_context_ptr> ctxs;
@@ -211,7 +247,7 @@ private:
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info     = llama_kv_cache_unified::defrag_info;

// used for errors
@@ -231,7 +267,7 @@ public:
// used to create a batch processing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);

virtual ~llama_kv_cache_unified_context();
@@ -257,11 +293,16 @@ public:
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

void set_input_k_shift   (ggml_tensor * dst) const;
void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

@@ -283,10 +324,10 @@ private:
// batch processing context
//

// the index of the cur ubatch to process
size_t i_cur = 0;

slot_info_vec_t sinfos;

std::vector<llama_ubatch> ubatches;

@@ -297,7 +338,4 @@ private:
// a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears
int32_t n_kv;
};
@@ -105,10 +105,30 @@ public:
res.resize(n);

for (uint32_t j = 0; j < n; ++j) {
const auto idx = i + j;

res.pos[j] = pos[idx];
res.seq[j] = seq[idx];

assert(shift[idx] == 0);
}

return res;
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
res.resize(idxs.size());
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}

return res;
@@ -119,26 +139,58 @@ public:
assert(i + other.pos.size() <= pos.size());

for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = i + j;

if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(i + j);
}

if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(i + j);
}

if (pos[idx] != -1) {
seq_pos_rm(i + j);
}

pos[idx] = other.pos[j];
seq[idx] = other.seq[j];

if (pos[idx] != -1) {
seq_pos_add(i + j);
}

assert(shift[idx] == 0);
}
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = idxs[j];
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(idx);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(idx);
}
if (pos[idx] != -1) {
seq_pos_rm(idx);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(idx);
}
assert(shift[idx] == 0);
}
}
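For illustration, a hedged sketch of how the index-based cp()/set() pair above can be used to snapshot and restore the same, possibly non-contiguous, cells; cells is assumed to be an existing llama_kv_cells_unified instance:

std::vector<uint32_t> idxs = {3, 9, 10};

llama_kv_cells_unified backup = cells.cp(idxs);   // copy the state of cells 3, 9 and 10
// ... the cells get overwritten, e.g. by apply_ubatch() ...
cells.set(idxs, backup);                          // restore exactly those cells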
@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
}

if (ubatch.n_tokens == 0) {
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
ubatches.push_back(std::move(ubatch)); // NOLINT
}

if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}

// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
@@ -92,6 +92,8 @@ private:
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;

// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -107,7 +109,7 @@ public:
// init success
llama_memory_hybrid_context(
llama_memory_hybrid * mem,
slot_info_vec_t sinfos_attn,
std::vector<llama_ubatch> ubatches);

~llama_memory_hybrid_context() = default;
@@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
}

if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
@@ -213,23 +213,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
} break;
case GGML_OP_SSM_CONV:
{
const int64_t n_seq_tokens = 512;
const int64_t n_seqs = 3;
ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
op_tensor = ggml_ssm_conv(ctx, conv_x, w);
} break;
case GGML_OP_SSM_SCAN:
{
// w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2
const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
const int64_t n_head = w->ne[1];
const int64_t head_dim = hparams.ssm_d_inner / n_head;
const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1;
const int64_t n_seq_tokens = 512;
const int64_t n_seqs = 3;
ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
} break;
case GGML_OP_RWKV_WKV6:
{
@@ -1086,6 +1090,38 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MAMBA2:
{
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24:
switch (hparams.n_embd) {
case 768: type = LLM_TYPE_SMALL; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 48:
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_MEDIUM; break;
case 1536: type = LLM_TYPE_LARGE; break;
case 2048: type = LLM_TYPE_XL; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 64:
switch (hparams.n_embd) {
case 2560: type = LLM_TYPE_3B; break;
case 4096: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_XVERSE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3216,6 +3252,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
}
} break;
case LLM_ARCH_MAMBA2:
{
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
// only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner);
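To make the projection width concrete, a worked example under an assumed Mamba2-2.7B-like configuration (illustrative defaults, not values read from any checkpoint): n_embd = 2560, d_inner = 2*n_embd = 5120, d_state = 128, n_group = 1, n_head = 80.

// d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
static_assert(2*5120 + 2*1*128 + 80 == 10576, "in_proj width for the assumed Mamba2-2.7B-like config");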
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
{
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed, duplicated to allow offloading
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
}
@@ -4727,10 +4811,14 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
}

if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_n_group      = %u\n", __func__, hparams.ssm_n_group);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n", __func__, hparams.ssm_dt_b_c_rms);

if (!classifier_labels.empty()) {
@@ -9765,9 +9853,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
};

struct llm_build_mamba : public llm_graph_context {
llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
@@ -9785,7 +9871,11 @@ struct llm_build_mamba : public llm_graph_context {
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

if (model.arch == LLM_ARCH_MAMBA2) {
cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
} else {
cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
}

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -9819,11 +9909,11 @@ struct llm_build_mamba : public llm_graph_context {
ggml_build_forward_expand(gf, cur);
}
// TODO: split
ggml_tensor * build_mamba_layer(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_model & model,
const llama_ubatch & ubatch,
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
@@ -9834,6 +9924,8 @@ struct llm_build_mamba : public llm_graph_context {
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank;
const int64_t n_head = d_inner;
const int64_t head_dim = 1;
const int64_t n_seqs = ubatch.n_seqs;
// Some variants of the Mamba arch (e.g. FalconMamba) do apply layer norm on B and Dt layers
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
@@ -9849,15 +9941,8 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);

// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9906,8 +9991,8 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
// split
ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) {
@@ -9920,23 +10005,36 @@ struct llm_build_mamba : public llm_graph_context {
dt = build_lora_mm(model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);

cur = x;
x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);

ggml_tensor * A = model.layers[il].ssm_a;

// use the states and the indices provided by build_recurrent_state
// (this is necessary in order to properly use the states before they are overwritten,
// while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

// Custom operator to optimize the parallel associative scan
// as described in the Annex D of the Mamba paper.
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};

ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));

ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

// TODO: skip computing output earlier for unused tokens

y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));

// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@@ -9945,7 +10043,136 @@ struct llm_build_mamba : public llm_graph_context {
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
// cb(cur, "mamba_out", il);
return cur;
}
ggml_tensor * build_mamba2_layer(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_model & model,
const llama_ubatch & ubatch,
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto kv_head = mctx_cur->get_head();
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = d_inner / n_head;
const int64_t n_group = hparams.ssm_n_group;
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
// split the above in three
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
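As a hedged cross-check of the three views above: per token, z takes the first d_inner elements, xBC the next d_inner + 2*n_group*d_state, and dt the trailing n_head, which together span exactly d_in_proj. With the same illustrative configuration as before (d_inner = 5120, n_group = 1, d_state = 128, n_head = 80):

// z: [0, 5120), xBC: [5120, 10496), dt: [10496, 10576)
static_assert(5120 + (5120 + 2*1*128) + 80 == 10576, "z + xBC + dt cover d_in_proj exactly");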
// conv
{
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
// copy last (d_conv - 1) columns back into the state cache
ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, last_conv,
ggml_view_1d(ctx0, conv_states_all,
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
// 1D convolution
// The equivalent is to make a self-overlapping view of conv_x
// over d_conv columns at each stride in the 3rd dimension,
// then element-wise multiply that with the conv1d weight,
// then sum the elements of each row,
// (the last two steps are a dot product over rows (also doable with mul_mat))
// then permute away the ne[0] dimension,
// and then you're left with the resulting x tensor.
// For simultaneous sequences, all sequences need to have the same length.
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
// bias
xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
xBC = ggml_silu(ctx0, xBC);
}
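A scalar reference of the convolution described in the comments above (illustrative, single sequence; not the ggml kernel itself): for every channel, each output position is the dot product of d_conv consecutive inputs, taken from the carried-over state columns concatenated with the new tokens, with that channel's d_conv weights.

#include <cstdint>

// conv_x: n_channels rows of (d_conv - 1 + n_out) floats; w: n_channels rows of d_conv floats
// out: n_out rows of n_channels floats (one xBC vector per token)
static void ssm_conv_ref(const float * conv_x, const float * w, float * out,
                         int64_t d_conv, int64_t n_channels, int64_t n_out) {
    for (int64_t c = 0; c < n_channels; ++c) {
        for (int64_t t = 0; t < n_out; ++t) {
            float acc = 0.0f;
            for (int64_t k = 0; k < d_conv; ++k) {
                acc += conv_x[c*(d_conv - 1 + n_out) + t + k] * w[c*d_conv + k];
            }
            out[t*n_channels + c] = acc;
        }
    }
}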
// ssm
{
// These correspond to V K Q in SSM/attention duality
ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
// {n_head, n_seq_tokens, n_seqs}
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
ggml_tensor * A = model.layers[il].ssm_a;
// use the states and the indices provided by build_recurrent_state
// (this is necessary in order to properly use the states before they are overwritten,
// while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
// grouped RMS norm
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
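A scalar sketch of the grouped RMS norm applied above (illustrative, one token): each group of d_inner/n_group channels is normalized independently and then scaled by the corresponding slice of the ssm_norm weight.

#include <cmath>
#include <cstdint>

static void grouped_rms_norm_ref(float * y, const float * w, int64_t d_inner, int64_t n_group, float eps) {
    const int64_t gsz = d_inner / n_group;
    for (int64_t g = 0; g < n_group; ++g) {
        float sum_sq = 0.0f;
        for (int64_t i = 0; i < gsz; ++i) {
            sum_sq += y[g*gsz + i] * y[g*gsz + i];
        }
        const float scale = 1.0f / std::sqrt(sum_sq/gsz + eps);
        for (int64_t i = 0; i < gsz; ++i) {
            y[g*gsz + i] *= scale * w[g*gsz + i];
        }
    }
}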
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(model.layers[il].ssm_out, y);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
// cb(cur, "mamba_out", il);
return cur;
}
@@ -14768,6 +14995,7 @@ llm_graph_result_ptr llama_model::build_graph(
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
} break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
{
llm = std::make_unique<llm_build_mamba>(*this, params, gf);
} break;
@@ -15028,6 +15256,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM:
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
@@ -172,6 +172,7 @@ struct llama_layer {
struct ggml_tensor * ffn_sub_norm = nullptr;
struct ggml_tensor * attn_norm_cross = nullptr;
struct ggml_tensor * attn_norm_enc = nullptr;
struct ggml_tensor * ssm_norm = nullptr;

// attention
struct ggml_tensor * wq = nullptr;
@@ -1430,8 +1430,7 @@ struct clip_graph {
ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, x);
embeddings = ggml_swiglu_split(ctx0, embeddings, x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
// arrangement of BOI/EOI token embeddings
@@ -1527,15 +1526,8 @@ struct clip_graph {
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);

// swiglu
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
cur = ggml_swiglu_swapped(ctx0, cur);

// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
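For clarity, a scalar sketch of the gating variants involved here (illustrative, mirroring the code this change replaces): ggml_swiglu_split(a, b) corresponds to silu(a) * b on two separate tensors, while the "swapped" form keeps the first half of a single tensor and gates it with silu of the second half, which is what the ultravox comment above calls out.

#include <cmath>
#include <cstdint>

static inline float silu_ref(float v) { return v / (1.0f + std::exp(-v)); }

// swapped swiglu over one vector x of even length n: out[i] = x[i] * silu(x[n/2 + i])
static void swiglu_swapped_ref(const float * x, float * out, int64_t n) {
    const int64_t half = n / 2;
    for (int64_t i = 0; i < half; ++i) {
        out[i] = x[i] * silu_ref(x[half + i]);
    }
}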
@@ -1794,33 +1786,40 @@ private:
cur = tmp;
}

switch (type_op) {
case FFN_SILU:
if (gate) {
cur = ggml_swiglu_split(ctx0, cur, tmp);
cb(cur, "ffn_swiglu", il);
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_silu", il);
} break;
case FFN_GELU:
if (gate) {
cur = ggml_geglu_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu", il);
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
} break;
case FFN_GELU_ERF:
if (gate) {
cur = ggml_geglu_erf_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu_erf", il);
} else {
cur = ggml_gelu_erf(ctx0, cur);
cb(cur, "ffn_gelu_erf", il);
} break;
case FFN_GELU_QUICK:
if (gate) {
cur = ggml_geglu_quick_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu_quick", il);
} else {
cur = ggml_gelu_quick(ctx0, cur);
cb(cur, "ffn_gelu_quick", il);
} break;
}

if (down) {