diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c2c55166e..dd80a4a05 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model): ] def set_vocab(self): - with open(self.dir_model / "chat_template.jinja") as f: - # quick hack to make sure chat template is added - self.gguf_writer.add_chat_template(f.read()) super().set_vocab() def set_gguf_parameters(self): @@ -4781,6 +4778,14 @@ class ARwkv7Model(Rwkv7Model): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4855,6 +4860,100 @@ class MambaModel(TextModel): return [(new_name, data_torch)] +@ModelBase.register("Mamba2ForCausalLM") +class Mamba2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + data_torch = data_torch.reshape((*data_torch.shape, 1)) + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + data_torch = data_torch.reshape((n_group, d_inner // n_group)) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R @@ -6615,12 +6714,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st # maybe we should fallback to text model's arch in that case, since not many models have both text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h deleted file mode 100644 index 154aa56a7..000000000 --- a/ggml/include/ggml-kompute.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "ggml.h" -#include "ggml-backend.h" - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -#define GGML_KOMPUTE_MAX_DEVICES 16 - -struct ggml_vk_device { - int index; - int type; // same as VkPhysicalDeviceType - size_t heapSize; - const char * name; - const char * vendor; - int subgroupSize; - uint64_t bufferAlignment; - uint64_t maxAlloc; -}; - -struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count); -bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name); -bool ggml_vk_has_vulkan(void); -bool ggml_vk_has_device(void); -struct ggml_vk_device ggml_vk_current_device(void); - -// -// backend API -// - -// forward declaration -typedef struct ggml_backend * ggml_backend_t; - -GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device); - -GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend); - -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); - -GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void); - -#ifdef __cplusplus -} -#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0dc4035ef..5202b4de7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -563,6 +563,8 @@ extern "C" { GGML_GLU_OP_REGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU_ERF, + GGML_GLU_OP_GEGLU_QUICK, GGML_GLU_OP_COUNT, }; @@ -659,6 +661,9 @@ extern "C" { // misc + GGML_API const char * ggml_version(void); + GGML_API const char * ggml_commit(void); + GGML_API void ggml_time_init(void); // call this once at the beginning of the program GGML_API int64_t ggml_time_ms(void); GGML_API int64_t ggml_time_us(void); @@ -1157,6 +1162,22 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + // A: n columns, r rows, // B: n columns, r rows, GGML_API struct ggml_tensor * ggml_glu_split( @@ -1180,6 +1201,16 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, @@ -1523,8 +1554,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // a [ne0, ne01, ne02, ne03] + // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional + // + // broadcast: + // ne02 % ne12 == 0 + // ne03 % ne13 == 0 + // // fused soft_max(a*scale + mask*(ALiBi slope)) - // mask is optional // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, @@ -1987,11 +2024,17 @@ extern "C" { #define GGML_KQ_MASK_PAD 64 - // q: [n_embd_k, n_batch, n_head, 1] - // k: [n_embd_k, n_kv, n_head_kv, 1] - // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd_v, n_head, n_batch, 1] !! permuted !! + // q: [n_embd_k, n_batch, n_head, ne3 ] + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! + // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! + // + // broadcast: + // n_head % n_head_kv == 0 + // n_head % ne32 == 0 + // ne3 % ne33 == 0 + // GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2030,7 +2073,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 342f3f474..de1747fdb 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -61,10 +61,6 @@ #include "ggml-cann.h" #endif -#ifdef GGML_USE_KOMPUTE -#include "ggml-kompute.h" -#endif - // disable C++17 deprecation warning for std::codecvt_utf8 #if defined(__clang__) # pragma clang diagnostic push @@ -189,9 +185,6 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif -#ifdef GGML_USE_KOMPUTE - register_backend(ggml_backend_kompute_reg()); -#endif #ifdef GGML_USE_CPU register_backend(ggml_backend_cpu_reg()); #endif @@ -576,7 +569,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("cann", silent, dir_path); ggml_backend_load_best("cuda", silent, dir_path); ggml_backend_load_best("hip", silent, dir_path); - ggml_backend_load_best("kompute", silent, dir_path); ggml_backend_load_best("metal", silent, dir_path); ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0d59f5bb0..c47c27185 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2186,6 +2186,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index dd83efde7..aaeee614a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_geglu_erf + +static void ggml_compute_forward_geglu_erf_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_erf_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_erf_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu_quick + +static void ggml_compute_forward_geglu_quick_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_quick_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_quick_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -5232,14 +5518,17 @@ static void ggml_compute_forward_soft_max_f32( memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // TODO: handle transposed/permuted matrices - const int ith = params->ith; const int nth = params->nth; GGML_TENSOR_UNARY_OP_LOCALS - //const int64_t ne11 = src1 ? src1->ne[1] : 1; + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -5249,68 +5538,66 @@ static void ggml_compute_forward_soft_max_f32( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); + + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif } } - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); - - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif } } @@ -7766,7 +8053,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; + ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; @@ -7798,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(VKQ32, 0, DV*sizeof(float)); } - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -8336,120 +8623,210 @@ void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // s - const ggml_tensor * src1 = dst->src[1]; // x - const ggml_tensor * src2 = dst->src[2]; // dt - const ggml_tensor * src3 = dst->src[3]; // A - const ggml_tensor * src4 = dst->src[4]; // B - const ggml_tensor * src5 = dst->src[5]; // C + const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} + const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} + const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); - // rows per thread - const int dr = (nr + nth - 1)/nth; + // heads per thread + const int dh = (nh + nth - 1)/nth; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); - #ifdef __ARM_FEATURE_SVE - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + const int32_t * ids = (const int32_t *) src6->data; - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); - svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); - svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + for (int i2 = 0; i2 < nt; ++i2) { + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - for (int64_t k = 0; k < nc; k += svcntw()) { - svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); - svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); + if (src3->ne[0] == 1) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop - svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); - t1 = exp_ps_sve(svptrue_b32(), t1); - svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); - vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); - r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; + float sumf = 0.0f; +#if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int ggml_f32_epr = svcntw(); + const int ggml_f32_step = 1 * ggml_f32_epr; - GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); - } - y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); - } - } - } + const int np = (nc & ~(ggml_f32_step - 1)); + + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + for (int i = 0; i < np; i += ggml_f32_step) { + // TODO: maybe unroll more? + for (int j = 0; j < 1; j++) { + GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + + t0 = GGML_F32_VEC_MUL(t0, adA); + t1 = GGML_F32_VEC_MUL(t1, axdt); + + t0 = GGML_F32_VEC_ADD(t0, t1); + + sum = GGML_F32_VEC_FMA(sum, t0, t2); + + GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0); + } + } + + sumf = GGML_F32xt_REDUCE_ONE(sum); #else - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + const int np = (nc & ~(GGML_F32_STEP - 1)); - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + #endif +#else + const int np = 0; +#endif + // d_state + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; +#if defined(__ARM_FEATURE_SVE) + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + // d_state + // TODO: what happens when (d_state % svcntw()) != 0? + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(t2, vs0, t1); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[ii*nc + k], vs0); + } + y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector); +#else + float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; +#endif } - y[i1] = sumf; } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } - #endif + } } void ggml_compute_forward_ssm_scan( @@ -8688,6 +9065,14 @@ void ggml_compute_forward_glu( { ggml_compute_forward_swiglu(params, dst); } break; + case GGML_GLU_OP_GEGLU_ERF: + { + ggml_compute_forward_geglu_erf(params, dst); + } break; + case GGML_GLU_OP_GEGLU_QUICK: + { + ggml_compute_forward_geglu_quick(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index b68ac0dd6..b4ad68c9f 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index ed5d7aefc..a8156011e 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = 0; i < np; i += ggml_f32_step) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); + sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); + sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8); } // leftovers // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop @@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only if (np2 < n) { diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index d5507d756..1f5857a23 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx); GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx); GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx); GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx); GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx); GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx); GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx); GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } @@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); } @@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_ } } +inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + float xi = x[i]; + y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i]; + } +} + +inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float xi = GGML_CPU_FP16_TO_FP32(x[i]); + float gi = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi); + } +} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i]; + } +} +#else +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]) * g[i]; + } +} +#endif + +inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a28ddbc15..8fba68f19 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -179,6 +179,20 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) +#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 0ce4afbb2..0c8b08197 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32), smpbo); cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32), smpbo); cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfab2b5eb..075f14a49 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -851,7 +853,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e230f6d49..709589854 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 5032c4f71..f86b73248 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); @@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 32673adb5..124d5d3e8 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index d44af148a..8cb75d5b0 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16( K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 953967917..c22baf417 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32( Q += nb02* blockIdx.z + nb01*ic0; K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 0dd0f12bf..ca78eb7e1 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16( constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5d9503ee6..04d55fe8b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2319,6 +2319,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_GLU_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_cuda_op_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_cuda_op_geglu_quick(ctx, dst); + break; default: return false; } @@ -3121,6 +3127,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]); default: return false; @@ -3326,12 +3334,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: return true; + case GGML_OP_SSM_SCAN: { + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports d_state == 128 && d_head % 16 == 0) + return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0; + } else { + // Mamba + // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1) + return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1; + } + } + case GGML_OP_SSM_CONV: { + // assumes d_inner % threads == 0 + return op->src[0]->ne[1] % 128 == 0; + } case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: return true; case GGML_OP_SOFT_MAX_BACK: { @@ -3380,6 +3402,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } + // TODO: support broadcast + // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but + // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index bae1a58c3..4230dd814 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3017,14 +3017,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index aac6e0999..14543e978 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -2,6 +2,7 @@ #include "ggml.h" #include "softmax.cuh" #include +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. #ifdef __clang__ @@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32(half val) { #endif // __clang__ template static __global__ void soft_max_f32( - const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, - const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + const float * x, const T * mask, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; const int tid = threadIdx.x; - const int rowx = blockIdx.x; - const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + //TODO: noncontigous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; x += int64_t(rowx)*ncols; - mask += int64_t(rowy)*ncols * (mask != nullptr); + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); dst += int64_t(rowx)*ncols; const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; @@ -38,7 +70,7 @@ static __global__ void soft_max_f32( const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication @@ -55,7 +87,7 @@ static __global__ void soft_max_f32( break; } - const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32( } } +template +static void launch_soft_max_kernels(const float * x, const T * mask, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) +{ + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>> + (x, mask, dst, p); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>>(x, mask, dst, p); +} + + template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) { int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); - const dim3 block_nums(nrows_x, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { - switch (ncols_x) { - case 32: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 64: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 128: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 256: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 512: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 1024: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 2048: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 4096: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - default: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - } + + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, params); } } @@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = src0->ne[0]; + float scale = 1.0f; float max_bias = 0.0f; @@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2d34b8360..dc3b1a9a8 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -4,16 +4,15 @@ template __global__ void __launch_bounds__(splitD, 2) ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, - const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * __restrict__ dst, const int64_t L) { - GGML_UNUSED(src1_nb0); - GGML_UNUSED(src2_nb0); + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t d_inner, const int64_t L) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int bidx = blockIdx.x; // split along B - const int bidy = blockIdx.y; // split along D + const int bidx = blockIdx.x; // split along B (sequences) + const int bidy = blockIdx.y; // split along D (d_inner) const int tid = threadIdx.x; const int wid = tid / 32; const int wtid = tid % 32; @@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2) float * smem_A = smem; float * smem_s0 = smem_A + splitD * stride_sA; - const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); - const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); + const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); - const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); - float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); - float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3)); + const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3)); + float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float)); + float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2); - const int stride_s0 = src0_nb1 / sizeof(float); - const int stride_x = src1_nb1 / sizeof(float); + const int stride_s0 = src0_nb2 / sizeof(float); + const int stride_x = src1_nb2 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float); const int stride_A = src3_nb1 / sizeof(float); - const int stride_B = src4_nb1 / sizeof(float); - const int stride_C = src5_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); const int stride_s = stride_s0; - const int stride_y = stride_x; + const int stride_y = d_inner; // can N not be 16? for example 32? if (N == 16) { @@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2) } } +// assumes as many threads as d_state +template +__global__ void __launch_bounds__(d_state, 1) + ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + + const int head_idx = (blockIdx.x * splitH) / d_head; + const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; + + const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + + const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); + const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; + float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + // strides across n_seq_tokens + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[splitH]; + // for the parallel accumulation + __shared__ float stateC[splitH * d_state]; + +#pragma unroll + for (int j = 0; j < splitH; j++) { + state[j] = s0_block[j * d_state + threadIdx.x]; + } + + for (int64_t i = 0; i < n_tok; i++) { + // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements + // TODO: only calculate B and C once per head group + // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. + float dt_soft_plus = dt_block[i * stride_dt]; + if (dt_soft_plus <= 20.0f) { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + const float dA = expf(dt_soft_plus * A_block[0]); + const float B = B_block[i * stride_B + threadIdx.x]; + const float C = C_block[i * stride_C + threadIdx.x]; + + // across d_head +#pragma unroll + for (int j = 0; j < splitH; j++) { + const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; + + state[j] = (state[j] * dA) + (B * x_dt); + + stateC[j * d_state + threadIdx.x] = state[j] * C; + } + + __syncthreads(); + + // parallel accumulation for stateC + // TODO: simplify + { + static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); + static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); + + // reduce until w matches the warp size + // TODO: does this work even when the physical warp size is 64? +#pragma unroll + for (int w = d_state; w > WARP_SIZE; w >>= 1) { + // (assuming there are d_state threads) +#pragma unroll + for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { + // TODO: check for bank conflicts + const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); + stateC[k] += stateC[k + (w >> 1)]; + + } + __syncthreads(); + } + + static_assert(splitH >= d_state / WARP_SIZE); + +#pragma unroll + for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { + float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; + y = warp_reduce_sum(y); + + // store the above accumulations + if (threadIdx.x % WARP_SIZE == 0) { + const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); + y_block[i * stride_y + k] = y; + } + } + } + } + + // write back the state +#pragma unroll + for (int j = 0; j < splitH; j++) { + s_block[j * d_state + threadIdx.x] = state[j]; + } +} + static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, - const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, - const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, - const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { const int threads = 128; - // todo: consider D cannot be divided,does this situation exist? - GGML_ASSERT(D % threads == 0); - const dim3 blocks(B, (D + threads - 1) / threads, 1); - const int smem_size = (threads * (N + 1) * 2) * sizeof(float); - if (N == 16) { - ssm_scan_f32<128, 16><<>>( - src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, - src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + if (src3_nb1 == sizeof(float)) { + // Mamba-2 + if (d_state == 128) { + GGML_ASSERT(d_state % threads == 0); + // NOTE: can be any power of two between 4 and 64 + const int splitH = 16; + GGML_ASSERT(head_dim % splitH == 0); + const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); + ssm_scan_f32_group<16, 128><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=128."); + } } else { - GGML_ABORT("doesn't support N!=16."); + // Mamba-1 + GGML_ASSERT(n_head % threads == 0); + GGML_ASSERT(head_dim == 1); + GGML_ASSERT(n_group == 1); + const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); + const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); + if (d_state == 16) { + ssm_scan_f32<128, 16><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=16."); + } } } @@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - - // const int64_t d_state = src0->ne[0]; - // const int64_t d_inner = src0->ne[1]; - // const int64_t l = src1->ne[1]; - // const int64_t b = src0->ne[2]; + const struct ggml_tensor * src6 = dst->src[6]; // ids const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nr = src0->ne[1]; // head_dim or 1 + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; // n_group + const int64_t n_t = src1->ne[2]; // number of tokens per sequence + const int64_t n_s = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; @@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * src3_d = (const float *) src3->data; const float * src4_d = (const float *) src4->data; const float * src5_d = (const float *) src5->data; + const int32_t * src6_d = (const int32_t *) src6->data; float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], - src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], - src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d, + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ba3c0f137..f9c7b83c4 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary_gated(ctx, dst); } +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 9094f1d0b..289d690e5 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7a9aab316..752d55c21 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -229,7 +229,11 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne32; + int32_t ne33; uint64_t nb31; + uint64_t nb32; + uint64_t nb33; int32_t ne1; int32_t ne2; float scale; @@ -461,9 +465,21 @@ typedef struct { } ggml_metal_kargs_sum_rows; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float scale; float max_bias; float m0; @@ -499,26 +515,25 @@ typedef struct { typedef struct { int64_t d_state; int64_t d_inner; + int64_t n_head; + int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - uint64_t nb00; uint64_t nb01; uint64_t nb02; - uint64_t nb10; + uint64_t nb03; uint64_t nb11; uint64_t nb12; uint64_t nb13; - uint64_t nb20; uint64_t nb21; uint64_t nb22; - uint64_t nb30; uint64_t nb31; - uint64_t nb40; uint64_t nb41; uint64_t nb42; - uint64_t nb50; + uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t nb53; } ggml_metal_kargs_ssm_scan; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index ceae3c7c3..d7f260a00 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -217,6 +217,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, @@ -529,6 +530,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_REGLU, GGML_METAL_KERNEL_TYPE_GEGLU, GGML_METAL_KERNEL_TYPE_SWIGLU, + GGML_METAL_KERNEL_TYPE_GEGLU_ERF, + GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, @@ -1196,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); @@ -1508,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); @@ -1691,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -1725,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); @@ -2454,6 +2462,12 @@ static bool ggml_metal_encode_node( case GGML_GLU_OP_SWIGLU: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; break; + case GGML_GLU_OP_GEGLU_ERF: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; + break; + case GGML_GLU_OP_GEGLU_QUICK: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline; + break; default: GGML_ABORT("fatal error"); } @@ -2644,10 +2658,7 @@ static bool ggml_metal_encode_node( memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -2707,6 +2718,18 @@ static bool ggml_metal_encode_node( /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, /*.scale =*/ scale, /*.max_bias =*/ max_bias, /*.m0 =*/ m0, @@ -2726,7 +2749,7 @@ static bool ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DIAG_MASK_INF: { @@ -2800,71 +2823,91 @@ static bool ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t d_state = ne00; const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + id pipeline = nil; + + if (ne30 == 1) { + // Mamba-2 + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb20 =*/ nb20, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb30 =*/ nb30, - /*.nb31 =*/ nb31, - /*.nb40 =*/ nb40, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb50 =*/ nb50, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; @@ -2874,10 +2917,17 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&args length:sizeof(args) atIndex:7]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + [encoder setBytes:&args length:sizeof(args) atIndex:8]; - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_RWKV_WKV6: { @@ -4979,7 +5029,11 @@ static bool ggml_metal_encode_node( /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, /*.scale =*/ scale, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index dac45c7a9..22240bab4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } void quantize_q4_0(device const float * src, device block_q4_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max float max = 0.0f; @@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) { } void quantize_q5_0(device const float * src, device block_q5_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max float max = 0.0f; @@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re } void quantize_q8_0(device const float * src, device block_q8_0 & dst) { +#pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max for (int j = 0; j < QK8_0; j++) { @@ -1258,6 +1261,50 @@ kernel void kernel_swiglu( } } +kernel void kernel_geglu_erf( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f+erf_approx(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_quick( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -1320,24 +1367,28 @@ kernel void kernel_soft_max( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1348,13 +1399,13 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1373,7 +1424,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; @@ -1385,7 +1436,7 @@ kernel void kernel_soft_max( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1404,7 +1455,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { pdst[i00] *= inv_sum; } } @@ -1416,23 +1467,27 @@ kernel void kernel_soft_max_4( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1443,14 +1498,14 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1469,7 +1524,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; @@ -1483,7 +1538,7 @@ kernel void kernel_soft_max_4( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1502,7 +1557,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { pdst4[i00] *= inv_sum; } } @@ -1588,7 +1643,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1596,46 +1651,119 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, device float * dst, constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); const int64_t nc = args.d_state; - // const int64_t nr = args.d_inner; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - // const int64_t n_s = args.n_seqs; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); - device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } y[0] = sumf; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + + const int64_t nc = args.d_state; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; + const int64_t n_t = args.n_seq_tokens; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = exp(dt_soft_plus * A[0]); + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + + // recurse + s0 = s; } } @@ -3776,7 +3904,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); const float m = pm[ic + tiisg]; @@ -4262,7 +4390,7 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*args.nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/glu.cl b/ggml/src/ggml-opencl/kernels/glu.cl index ba861d8b1..7cca16e6a 100644 --- a/ggml/src/ggml-opencl/kernels/glu.cl +++ b/ggml/src/ggml-opencl/kernels/glu.cl @@ -1,7 +1,9 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f +#define SQRT_2_INV 0.70710678118654752440084436210484f //------------------------------------------------------------------------------ // geglu @@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16( dst_row[i0] = silu*x1; } } + +//------------------------------------------------------------------------------ +// geglu_erf +//------------------------------------------------------------------------------ +kernel void kernel_geglu_erf( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_erf_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +//------------------------------------------------------------------------------ +// geglu_quick +//------------------------------------------------------------------------------ +kernel void kernel_geglu_quick( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + +kernel void kernel_geglu_quick_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99aee9655..ea484845c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -240,6 +240,21 @@ enum vk_device_architecture { INTEL_XE2, }; +// HSK x HSV +enum FaHeadSizes { + FA_HEAD_SIZE_64, + FA_HEAD_SIZE_80, + FA_HEAD_SIZE_96, + FA_HEAD_SIZE_112, + FA_HEAD_SIZE_128, + FA_HEAD_SIZE_192, + FA_HEAD_SIZE_192_128, + FA_HEAD_SIZE_256, + FA_HEAD_SIZE_576_512, + FA_HEAD_SIZE_UNSUPPORTED, + FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, +}; + static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { vk::PhysicalDeviceProperties props = device.getProperties(); @@ -457,6 +472,8 @@ struct vk_device_struct { vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; @@ -483,26 +500,11 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} - vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_split_k_reduce; @@ -649,6 +651,7 @@ struct vk_flash_attn_push_constants { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; uint32_t nb01; uint32_t nb02; @@ -659,7 +662,6 @@ struct vk_flash_attn_push_constants { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; @@ -674,6 +676,7 @@ struct vk_flash_attn_push_constants { uint32_t split_kv; uint32_t k_num; }; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); struct vk_op_push_constants { uint32_t KX; @@ -772,6 +775,14 @@ struct vk_op_rope_push_constants { struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; float scale; float max_bias; float m0; @@ -1010,7 +1021,7 @@ struct ggml_backend_vk_context { // number of additional consecutive nodes that are being fused with the // node currently being processed - uint32_t num_additional_fused_ops {}; + int num_additional_fused_ops {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -1706,6 +1717,35 @@ enum FaCodePath { FA_COOPMAT2, }; +static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { + if (hsk != 192 && hsk != 576 && hsk != hsv) { + return FA_HEAD_SIZE_UNSUPPORTED; + } + switch (hsk) { + case 64: return FA_HEAD_SIZE_64; + case 80: return FA_HEAD_SIZE_80; + case 96: return FA_HEAD_SIZE_96; + case 112: return FA_HEAD_SIZE_112; + case 128: return FA_HEAD_SIZE_128; + case 192: + if (hsv == 192) { + return FA_HEAD_SIZE_192; + } else if (hsv == 128) { + return FA_HEAD_SIZE_192_128; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + case 256: return FA_HEAD_SIZE_256; + case 576: + if (hsv == 512) { + return FA_HEAD_SIZE_576_512; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + default: return FA_HEAD_SIZE_UNSUPPORTED; + } +} + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; @@ -1726,8 +1766,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) { } } -static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + GGML_UNUSED(hsv); if (path == FA_SCALAR) { if (small_rows) { @@ -1751,7 +1792,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_ } // small cols to reduce register count - if (ggml_is_quantized(type) || D == 256) { + if (ggml_is_quantized(type) || hsk >= 256) { return {64, 32}; } return {64, 64}; @@ -2044,19 +2085,21 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) + // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. + const uint32_t D = (hsk|hsv); uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) ? scalar_flash_attention_workgroup_size : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -2065,26 +2108,29 @@ static void ggml_vk_load_shaders(vk_device& device) { // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -2793,6 +2839,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_GLU(geglu) CREATE_GLU(reglu) CREATE_GLU(swiglu) + CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) #undef CREATE_GLU ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -3703,7 +3751,6 @@ static void ggml_vk_instance_init() { } - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan @@ -6017,24 +6064,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t Br = scalar_flash_attention_num_large_rows; const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + + const uint32_t masksh = Bc * Br * sizeof(float); + + const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = coopmat1_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * acctype; - const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4; - const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = D / 4 + 2; + const uint32_t kshstride = hsk / 4 + 2; const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t slope = Br * sizeof(float); @@ -6042,7 +6112,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -6064,13 +6134,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const uint32_t nem1 = mask ? mask->ne[1] : 0; - const uint32_t nbm1 = mask ? mask->nb[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; - const uint32_t D = neq0; + const uint32_t HSK = nek0; + const uint32_t HSV = nev0; uint32_t N = neq1; const uint32_t KV = nek1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == HSV); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -6078,12 +6149,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == HSK); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); GGML_ASSERT(nev1 == nek1); @@ -6104,7 +6172,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); if (!coopmat_shape_supported || !coopmat_shmem_supported) { path = FA_SCALAR; @@ -6157,47 +6225,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx path = FA_SCALAR; } + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; + } + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]); + switch (path) { case FA_SCALAR: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0]; break; case FA_COOPMAT1: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0]; break; case FA_COOPMAT2: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0]; break; default: GGML_ASSERT(0); @@ -6227,7 +6273,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Try to use split_k when KV is large enough to be worth the overhead if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. - split_k = ctx->device->shader_core_count * 2 / workgroups_y; + split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. @@ -6237,9 +6283,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) - // and the per-row m and L values (ne1 rows). - const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; if (split_k_size > ctx->device->max_memory_allocation_size) { GGML_ABORT("Requested preallocation size is too large"); } @@ -6331,11 +6377,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, - nem1, + nem1, nem2, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, - nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; @@ -6358,13 +6403,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); - const std::array pc2 = { D, (uint32_t)ne1, split_k }; + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2, { (uint32_t)ne1, 1, 1 }); + pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -6558,6 +6603,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; case GGML_GLU_OP_SWIGLU: return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; default: break; } @@ -7690,7 +7739,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); const uint32_t nrows_y = (uint32_t)src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -7699,6 +7754,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, scale, max_bias, m0, m1, n_head_log2, @@ -8893,6 +8951,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: break; default: return false; @@ -9140,6 +9200,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); break; default: @@ -9358,6 +9420,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: buf = tensor->buffer; break; default: @@ -10168,6 +10232,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -10248,19 +10314,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; - switch (op->src[0]->ne[0]) { - case 64: - case 80: - case 96: - case 112: - case 128: - case 256: - break; - default: - return false; - } - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet + FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]); + if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { return false; } if (op->src[0]->type != GGML_TYPE_F32) { @@ -10272,6 +10327,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } + // TODO: support broadcast + // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but + // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 + if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) { + return false; + } // It's straightforward to support different K/V dequant, but would // significantly increase the number of pipelines if (op->src[1]->type != op->src[2]->type) { @@ -10340,6 +10401,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } } break; + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 + return false; + } break; case GGML_OP_CONT: case GGML_OP_CPY: case GGML_OP_DUP: @@ -10430,6 +10497,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ce230a8f7..788a5e065 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -11,7 +11,8 @@ #include "types.comp" #include "flash_attn_base.comp" -const uint32_t D_per_thread = D / D_split; +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - uint32_t offset = (iq2 + r) * D + c; + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } @@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; -shared vec4 Qf[Br][D / 4]; +shared vec4 Qf[Br][HSK / 4]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -53,18 +54,18 @@ void main() { uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; - [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t r = (idx + tid) / (D / 4); - if (r < Br && d < D / 4 && + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && i * Br + r < N) { Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; } } barrier(); - vec4 Of[Br][D_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + vec4 Of[Br][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = vec4(0.0); } @@ -99,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -112,7 +117,7 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -150,7 +155,7 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); @@ -191,14 +196,14 @@ void main() { Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = eMf[r] * Of[r][d]; } } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -255,7 +260,7 @@ void main() { Lf[r] = tmpsh[d_tid]; barrier(); - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = eMf * Of[r][d]; tmpshv4[tid] = Of[r][d]; @@ -277,11 +282,11 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -289,7 +294,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -305,18 +310,18 @@ void main() { Lfrcp[r] = 1.0 / Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -326,9 +331,9 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (i * Br + r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 61d90e2d8..6609f0bad 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = 0; -layout (constant_id = 5) const uint32_t D_split = 16; - +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; layout (push_constant) uniform parameter { uint32_t N; @@ -24,6 +24,7 @@ layout (push_constant) uniform parameter { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; uint32_t nb01; uint32_t nb02; @@ -34,7 +35,6 @@ layout (push_constant) uniform parameter { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index da478be24..e74e2fa93 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -13,7 +13,9 @@ #include "types.comp" #include "flash_attn_base.comp" -const uint32_t D_per_thread = D / D_split; +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + const uint32_t row_split = 4; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; @@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - uint32_t offset = (iq2 + r) * D + c; + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } @@ -44,14 +46,14 @@ const uint32_t MatBc = 16; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; -const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; -// Avoid padding for D==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +// Avoid padding for hsk==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; shared ACC_TYPE sfsh[Bc * sfshstride]; -const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4 shared f16vec4 ksh[Bc * kshstride]; shared float slope[Br]; @@ -74,18 +76,18 @@ void main() { uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; - [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t r = (idx + tid) / (D / 4); - if (r < Br && d < D / 4 && + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && i * Br + r < N) { Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = ACC_TYPEV4(0.0); } @@ -123,14 +125,18 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t c = (idx + tid) / (D / 4); - if (c < Bc && d < D / 4) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; @@ -145,14 +151,14 @@ void main() { } barrier(); - // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br - // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat SfMat = coopmat(0); coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < D / 16; ++d) { + for (uint32_t d = 0; d < HSK / 16; ++d) { coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; @@ -181,7 +187,7 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } barrier(); @@ -202,7 +208,7 @@ void main() { eMf[r] = exp(Moldf - Mf[r]); } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = float16_t(eMf[r]) * Of[r][d]; } @@ -217,7 +223,7 @@ void main() { Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); Lf[r] += Pf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -280,7 +286,7 @@ void main() { } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = float16_t(eMf[r]) * Of[r][d]; tmpshv4[tid] = Of[r][d]; @@ -300,11 +306,11 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); } @@ -312,7 +318,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -328,18 +334,18 @@ void main() { Lfrcp[r] = 1.0 / Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] *= float16_t(Lfrcp[r]); } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); } @@ -349,9 +355,9 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6acf67a03..8792d5195 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - if (r < N && c < D) { - uint32_t offset = (iq2 + r) * D + c; + if (r < N && c < HSV) { + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); } return elem; @@ -86,9 +86,9 @@ void main() { tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif - tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); - tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); - tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV); // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) @@ -104,16 +104,16 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; - coopmat Qf16; + coopmat Q; + coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; - coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK)); - Qf16 = coopmat(Q); + Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -130,15 +130,20 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1) { + m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); if (p.logit_softcap != 0.0f) { @@ -155,7 +160,7 @@ void main() { coopmat mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat(mv); } @@ -203,42 +208,42 @@ void main() { rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC); L = eM*L + rowsum; // This is the "diagonal" matrix in the paper, but since we do componentwise // multiply rather than matrix multiply it has the diagonal element smeared // across the row - coopmat eMdiag; + coopmat eMdiag; // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); + coopmat PV = coopmat(0); PV = coopMatMulAdd(P_A, V, PV); - O = eMdiag * O + coopmat(PV); + O = eMdiag * O + coopmat(PV); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; } - coopmat Ldiag; + coopmat Ldiag; // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); @@ -250,18 +255,18 @@ void main() { O = Ldiag*O; - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV); // permute dimensions tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e395685..599cef072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; + uint ne3; uint k_num; } p; @@ -19,13 +20,14 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; uint D = p.D; uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * k_num + n; - uint m_offset = D * N * k_num + N + n; + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; uint lm_stride = N * 2; // Compute the max m value for the row @@ -49,11 +51,11 @@ void main() { for (uint d = tid; d < D; d += BLOCK_SIZE) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * k + D * n + d; + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } O *= L; - data_d[D * n + d] = O; + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 000000000..cbd4cb36b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.comp" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 000000000..3a2a6897b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 51fc2dc7e..5bcd3b1e3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -6,6 +6,14 @@ layout (push_constant) uniform parameter { uint KX; uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; float scale; float max_bias; float m0; @@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } if (rowx >= p.nrows_x) { return; @@ -41,7 +57,7 @@ void soft_max(uint num_iters) { // ALiBi if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index + const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; @@ -67,7 +83,7 @@ void soft_max(uint num_iters) { FLOAT_TYPE b = FLOAT_TYPE(0); if (p.KY > 0 && col < p.KX) { - b = data_b[rowy * p.KX + col]; + b = data_b[rowy_start + col]; } FLOAT_TYPE v = a * p.scale + slope * b; @@ -111,7 +127,7 @@ void soft_max(uint num_iters) { if (idx < DATA_CACHE_SIZE) { val = exp(data_cache[idx] - max_val); } else { - val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); } sum += val; if (idx < DATA_CACHE_SIZE) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b0a1344a5..6ab03827b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -607,6 +607,10 @@ void process_shaders() { string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fe1e04965..c8eabd645 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21,6 +21,9 @@ #include #endif +#define GGML_VERSION "0.0.1" +#define GGML_COMMIT "KCPP" + #include #include #include @@ -474,6 +477,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; } +const char * ggml_version(void) { + return GGML_VERSION; +} + +const char * ggml_commit(void) { + return GGML_COMMIT; +} + // // timing // @@ -1145,9 +1156,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", "GEGLU", "SWIGLU", + "GEGLU_ERF", + "GEGLU_QUICK", }; -static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3"); +static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -2773,6 +2786,48 @@ struct ggml_tensor * ggml_swiglu_split( return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); } +// ggml_geglu_erf + +struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false); +} + +struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true); +} + +struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false); +} + +// ggml_geglu_quick + +struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false); +} + +struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true); +} + +struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false); +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( @@ -3679,9 +3734,10 @@ static struct ggml_tensor * ggml_soft_max_impl( if (mask) { GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(ggml_is_matrix(mask)); GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); + GGML_ASSERT(a->ne[2]%mask->ne[2] == 0); + GGML_ASSERT(a->ne[3]%mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4702,13 +4758,17 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + GGML_ASSERT(q->ne[3] == k->ne[3]); + GGML_ASSERT(q->ne[3] == v->ne[3]); + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[2] == 1); - GGML_ASSERT(mask->ne[3] == 1); GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + + GGML_ASSERT(q->ne[2] % mask->ne[2] == 0); + GGML_ASSERT(q->ne[3] % mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4836,7 +4896,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -4859,36 +4918,49 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors + GGML_ASSERT(A->ne[0] == d_state); + } } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->src[0] = s; @@ -4897,6 +4969,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = ids; return result; } @@ -6037,13 +6110,28 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_GLU: { + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_SWIGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split swiglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0)); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); + } + } break; + default: { + GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); + } //break; + } + } break; case GGML_OP_NONE: { // noop } break; case GGML_OP_COUNT: default: { - fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); - GGML_ABORT("fatal error"); + GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); } //break; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b5ba933cb..c12609c6d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -170,6 +170,7 @@ class Keys: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class WKV: @@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum): RWKV7 = auto() ARWKV7 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -429,6 +431,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() @@ -628,6 +631,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -730,6 +734,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", @@ -1714,6 +1719,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2497,6 +2515,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d32cd479a..a7ecf3d31 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -714,8 +714,8 @@ class GGUFWriter: def add_clamp_kqv(self, value: float) -> None: self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) - def add_shared_kv_layers(self, value: float) -> None: - self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) + def add_shared_kv_layers(self, value: int) -> None: + self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) @@ -861,6 +861,9 @@ class GGUFWriter: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b30f77dbe..51634ef6b 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -477,7 +477,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( @@ -574,6 +574,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 3f541b0c0..635fcef35 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -245,9 +245,18 @@ class SpecialVocab: if not tokenizer_config: return True chat_template_alt = None - chat_template_file = path / 'chat_template.json' - if chat_template_file.is_file(): - with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_json = path / 'chat_template.json' + chat_template_jinja = path / 'chat_template.jinja' + if chat_template_jinja.is_file(): + with open(chat_template_jinja, encoding = 'utf-8') as f: + chat_template_alt = f.read() + if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')): + chat_template_alt = [{'name': 'default', 'template': chat_template_alt}] + for template_path in additional_templates: + with open(template_path, encoding = 'utf-8') as fp: + chat_template_alt.append({'name': template_path.stem, 'template': fp.read()}) + elif chat_template_json.is_file(): + with open(chat_template_json, encoding = 'utf-8') as f: chat_template_alt = json.load(f).get('chat_template') chat_template = tokenizer_config.get('chat_template', chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): diff --git a/klite.embd b/klite.embd index e63802427..aefda50ea 100644 --- a/klite.embd +++ b/klite.embd @@ -3208,6 +3208,7 @@ Current version indicated by LITEVER below. var schedule_multiplayer_major_change = false; var last_request_str = "No Requests Available"; //full context of last submitted request var last_response_obj = null; + var last_response_streamlog = ""; var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp var kai_poll_recoverykey = ""; //for recovering a lost polled streaming in case of disconnect. var globalabortcontroller = null; @@ -5634,6 +5635,10 @@ Current version indicated by LITEVER below. }, transform(chunk, ctrl) { ctrl.buf += chunk; + if(chunk) + { + last_response_streamlog += chunk; + } let evs = []; let m; while ((m = /^data: ?(.*)(\r?\n){2}/m.exec(ctrl.buf)) !== null) { @@ -9343,6 +9348,10 @@ Current version indicated by LITEVER below. { lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj); } + if(last_response_streamlog) + { + lr += "\n\nResponse:\n" + last_response_streamlog; + } msgbox(lr,"Last Request Info",false); } function show_last_logprobs() @@ -10401,7 +10410,7 @@ Current version indicated by LITEVER below. desired_oai_ep = transform_oai_ep(desired_oai_ep); let oaiheaders = {}; - if(desired_oai_key!=""){ + if(desired_oai_key!="" && !desired_oai_ep.toLowerCase().includes("pollinations.ai")){ oaiheaders["Authorization"] = "Bearer " + desired_oai_key; }; if (desired_oai_ep.toLowerCase().includes("api.mistral.ai")) { @@ -14041,6 +14050,7 @@ Current version indicated by LITEVER below. redo_arr = []; last_request_str = "No Requests Available"; last_response_obj = null; + last_response_streamlog = ""; retry_prev_text = []; retry_preserve_last = false; retry_in_progress = false; @@ -16471,6 +16481,7 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(submit_payload); last_response_obj = null; + last_response_streamlog = ""; if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) { let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint); kobold_api_stream_sse(sub_endpt, submit_payload); @@ -16643,11 +16654,15 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(oai_payload); last_response_obj = null; + last_response_streamlog = ""; let oaiheaders = { - 'Content-Type': 'application/json', - 'Authorization': 'Bearer ' + custom_oai_key + 'Content-Type': 'application/json' }; + if (!targetep.toLowerCase().includes("pollinations.ai")) { + oaiheaders['Authorization'] = 'Bearer ' + custom_oai_key; + } + if(targetep.toLowerCase().includes("openrouter.ai")) { oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net"; @@ -16780,6 +16795,7 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(claude_payload); last_response_obj = null; + last_response_streamlog = ""; let claudeheaders = { 'Content-Type': 'application/json', @@ -16970,6 +16986,7 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(payload); last_response_obj = null; + last_response_streamlog = ""; let geminiheaders = { 'Content-Type': 'application/json' }; if(is_browser_supports_sse() && localsettings.tokenstreammode!=0) @@ -17018,6 +17035,7 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(cohere_payload); last_response_obj = null; + last_response_streamlog = ""; let cohere_headers = { 'Content-Type': 'application/json', 'Authorization': 'Bearer ' + custom_cohere_key @@ -17093,6 +17111,7 @@ Current version indicated by LITEVER below. last_request_str = JSON.stringify(submit_payload); last_response_obj = null; + last_response_streamlog = ""; fetch(horde_submit_endpoint, { method: 'POST', // or 'PUT' diff --git a/koboldcpp.py b/koboldcpp.py index c0b65aba9..b2241797c 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -62,7 +62,7 @@ dry_seq_break_max = 128 extra_images_max = 4 # global vars -KcppVersion = "1.95.1" +KcppVersion = "1.96" showdebug = True kcpp_instance = None #global running instance global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""} diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index aa21108a4..ab2405430 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -45,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -170,6 +171,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -1004,6 +1006,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1761,6 +1779,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1894,6 +1913,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { bool llm_arch_is_recurrent(const llm_arch & arch) { switch (arch) { case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: diff --git a/src/llama-arch.h b/src/llama-arch.h index 0771ec3eb..b769831df 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -49,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA3N, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -174,6 +175,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_WKV_HEAD_SIZE, @@ -293,6 +295,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 5676dedcb..b71041199 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -166,6 +166,8 @@ bool llama_batch_allocr::init( // note: tracking the other way around is not necessary for now //seq_cpl[s0][s1] = true; + + has_cpl = true; } } } @@ -404,6 +406,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const { return n_outputs; } +uint32_t llama_batch_allocr::get_n_used() const { + return n_used; +} + std::vector & llama_batch_allocr::get_out_ids() { return out_ids; } @@ -419,6 +425,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { void llama_batch_allocr::split_reset() { out_ids.clear(); + n_used = 0; + used.clear(); used.resize(get_n_tokens(), false); @@ -443,6 +451,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; ++cur_idx; @@ -458,9 +467,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { return ubatch_add(idxs, idxs.size(), false); } -llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { +llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { + if (sequential && has_cpl) { + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + + return {}; + } + std::vector cur_seq_set; + llama_seq_id last_seq_id = -1; + // determine the non-overlapping sequence sets participating in this ubatch for (int32_t i = 0; i < batch.n_tokens; ++i) { if (used[i]) { @@ -477,9 +494,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { } } + // accept only increasing sequence ids + if (sequential) { + add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + } + if (add) { cur_seq_set.push_back(seq_set[i]); + last_seq_id = batch.seq_id[i][0]; + if (cur_seq_set.size() > n_ubatch) { break; } @@ -528,6 +552,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { idxs_per_seq[s].push_back(idx); used[idx] = true; + ++n_used; ++cur_idx[s]; } @@ -569,6 +594,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; if (idxs.size() >= n_ubatch) { break; diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c537618..3420803ff 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -54,6 +54,7 @@ public: uint32_t get_n_tokens() const; uint32_t get_n_outputs() const; + uint32_t get_n_used() const; // the array of output indices in the order they were encountered during the ubatch splitting std::vector & get_out_ids(); @@ -69,7 +70,8 @@ public: llama_ubatch split_simple(uint32_t n_ubatch); // make ubatches of equal-length sequences sets - llama_ubatch split_equal(uint32_t n_ubatch); + // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids + llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); // sequence-set-wise split - each ubatch contains a single sequence-set llama_ubatch split_seq(uint32_t n_ubatch); @@ -112,6 +114,9 @@ private: using pos_set_t = std::set; using seq_cpl_t = std::vector; + // helper flag to quickly determine if there are any coupled sequences in the batch + bool has_cpl; + std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 @@ -125,6 +130,8 @@ private: // batch indices of the output std::vector out_ids; + uint32_t n_used; + // used[i] indicates if token i has already been used in a previous ubatch std::vector used; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 010300df6..7f0e8c67f 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->set_input_k_idxs(self_k_idxs, ubatch); + mctx->set_input_v_idxs(self_v_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - if (self_kq_mask_swa) { - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); - } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { @@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); const int64_t n_rs = mctx->get_recr()->get_n_rs(); @@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_one::set_input(const llama_ubatch *) { +void llm_graph_input_one::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); GGML_ASSERT(one && ggml_nelements(one) == 1); float f_one = 1.0f; ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); @@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto n_kv = inp->mctx->get_attn()->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_kq_mask, "KQ_mask", -1); + inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->kq_mask); inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; @@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const auto n_kv = mctx_cur->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & k_idxs = inp->get_k_idxs(); + const auto & v_idxs = inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = inp->get_kq_mask(); @@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn( // optionally store to KV cache if (k_cur) { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } if (v_cur) { - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); @@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & k_idxs = inp->get_k_idxs(); + const auto & v_idxs = inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = inp->get_kq_mask(); @@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs( uint32_t kv_head, uint32_t kv_size, int32_t rs_zero, - bool avoid_copies) const { + const llm_graph_get_rows_fn & get_state_rows) const { ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size); @@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); - ggml_tensor * output_states; - - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); @@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { - const auto * mctx_cur = static_cast(mctx); + const llm_graph_get_rows_fn & get_state_rows) const { + const auto * kv_state = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } ggml_tensor * llm_graph_context::build_rs( @@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { - const auto * mctx_cur = static_cast(mctx)->get_recr(); + const llm_graph_get_rows_fn & get_state_rows) const { + const auto * kv_state = static_cast(mctx)->get_recr(); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( diff --git a/src/llama-graph.h b/src/llama-graph.h index ceddb6021..7bdf65676 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -228,8 +228,8 @@ public: ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -249,10 +249,16 @@ public: void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -274,13 +280,23 @@ public: void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; } + ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -297,8 +313,8 @@ public: ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -319,10 +335,16 @@ public: ggml_tensor * s_copy; // I32 [kv_size] + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -336,7 +358,7 @@ public: llm_graph_input_one() {} virtual ~llm_graph_input_one() = default; - void set_input(const llama_ubatch *) override; + void set_input(const llama_ubatch * ubatch) override; ggml_tensor * one = nullptr; // F32 }; @@ -424,6 +446,9 @@ struct llm_graph_params { const llm_graph_cb & cb; }; +// used in build_rs to properly order writes and avoid unnecessary copies +using llm_graph_get_rows_fn = std::function; + struct llm_graph_context { const llm_arch arch; @@ -663,7 +688,7 @@ struct llm_graph_context { uint32_t kv_head, uint32_t kv_size, int32_t rs_zero, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; llm_graph_input_rs * build_rs_inp() const; @@ -673,7 +698,7 @@ struct llm_graph_context { ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rs( llm_graph_input_mem_hybrid * inp, @@ -681,7 +706,7 @@ struct llm_graph_context { ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index bba7a12dc..86c814d51 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // Corresponds to Mamba's conv_states size + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t llama_hparams::n_embd_s() const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index e85afe145..476d0a5ea 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -114,6 +114,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; // for hybrid state space models std::array recurrent_layer_arr; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 2680d3b92..abf7f62c5 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -117,20 +117,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads_base = kv_base->prepare(ubatches); - if (heads_base.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } - auto heads_swa = kv_swa->prepare(ubatches); - if (heads_swa.empty()) { + auto sinfos_base = kv_base->prepare(ubatches); + if (sinfos_base.empty()) { break; } - assert(heads_base.size() == heads_swa.size()); + auto sinfos_swa = kv_swa->prepare(ubatches); + if (sinfos_swa.empty()) { + break; + } + + assert(sinfos_base.size() == sinfos_swa.size()); return std::make_unique( - this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // if it fails, try equal split @@ -139,7 +144,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch); + auto ubatch = balloc.split_equal(n_ubatch, false); if (ubatch.n_tokens == 0) { break; @@ -148,20 +153,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads_base = kv_base->prepare(ubatches); - if (heads_base.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } - auto heads_swa = kv_swa->prepare(ubatches); - if (heads_swa.empty()) { + auto sinfos_base = kv_base->prepare(ubatches); + if (sinfos_base.empty()) { break; } - assert(heads_base.size() == heads_swa.size()); + auto sinfos_swa = kv_swa->prepare(ubatches); + if (sinfos_swa.empty()) { + break; + } + + assert(sinfos_base.size() == sinfos_swa.size()); return std::make_unique( - this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies @@ -224,13 +234,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, - std::vector heads_base, - std::vector heads_swa, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), - ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), + ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), + ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 46c1ed614..23205d826 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -74,6 +74,8 @@ private: class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { public: + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + // used for errors llama_kv_cache_unified_iswa_context(llama_memory_status status); @@ -90,8 +92,8 @@ public: // used to create a batch processing context from a batch llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, - std::vector heads_base, - std::vector heads_swa, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, std::vector ubatches); virtual ~llama_kv_cache_unified_iswa_context(); diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 23ca41473..5ed300b7a 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; + + const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); + supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + + if (!supports_set_rows) { + LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); + } } void llama_kv_cache_unified::clear(bool data) { @@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads = prepare(ubatches); - if (heads.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos = prepare(ubatches); + if (sinfos.empty()) { break; } return std::make_unique( - this, std::move(heads), std::move(ubatches)); + this, std::move(sinfos), std::move(ubatches)); } while (false); return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); @@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct return std::make_unique(this, lctx, do_shift, std::move(dinfo)); } -llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { - llama_kv_cache_unified::ubatch_heads res; +llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::slot_info_vec_t res; struct state { uint32_t head_old; // old position of the head, before placing the ubatch - uint32_t head_new; // new position of the head, after placing the ubatch + + slot_info sinfo; // slot info for the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch }; @@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std:: bool success = true; for (const auto & ubatch : ubatches) { + // non-continuous slots require support for ggml_set_rows() + const bool cont = supports_set_rows ? false : true; + // only find a suitable slot for the ubatch. don't modify the cells yet - const int32_t head_new = find_slot(ubatch); - if (head_new < 0) { + const auto sinfo_new = find_slot(ubatch, cont); + if (sinfo_new.empty()) { success = false; break; } // remeber the position that we found - res.push_back(head_new); + res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); // now emplace the ubatch - apply_ubatch(head_new, ubatch); + apply_ubatch(sinfo_new, ubatch); } // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->head_new, it->cells); + cells.set(it->sinfo.idxs, it->cells); head = it->head_old; } @@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d return updated; } -int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { +llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { const uint32_t n_tokens = ubatch.n_tokens; uint32_t head_cur = this->head; @@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { if (n_tokens > cells.size()) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return -1; + return { }; } if (debug > 0) { @@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { uint32_t n_tested = 0; + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; + + slot_info res; + + auto & idxs = res.idxs; + + idxs.reserve(n_tokens); + while (true) { - if (head_cur + n_tokens > cells.size()) { + if (head_cur + n_test > cells.size()) { n_tested += cells.size() - head_cur; head_cur = 0; continue; } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; + //const llama_pos pos = ubatch.pos[i]; //const llama_seq_id seq_id = ubatch.seq_id[i][0]; @@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { // - (disabled) mask causally, if the sequence is the same as the one we are inserting // - mask SWA, using current max pos for that sequence in the cache // always insert in the cell with minimum pos - bool can_use = cells.is_empty(head_cur + i); + bool can_use = cells.is_empty(idx); - if (!can_use && cells.seq_count(head_cur + i) == 1) { - const llama_pos pos_cell = cells.pos_get(head_cur + i); + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); // (disabled) causal mask // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(head_cur + i, seq_id)) { + //if (cells.seq_has(idx, seq_id)) { // can_use = pos_cell >= pos; //} if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { @@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { } } - if (!can_use) { - found = false; - head_cur += i + 1; - n_tested += i + 1; + head_cur++; + n_tested++; + + if (can_use) { + idxs.push_back(idx); + } else { break; } } - if (found) { + if (idxs.size() == n_tokens) { break; } + if (cont) { + idxs.clear(); + } + if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return -1; + return { }; } } - return head_cur; + // we didn't find a suitable slot - return empty result + if (idxs.size() < n_tokens) { + res.clear(); + } + + return res; } -void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch seq_pos_max_rm[s] = -1; } - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - if (!cells.is_empty(head_cur + i)) { - assert(cells.seq_count(head_cur + i) == 1); + assert(ubatch.n_tokens == sinfo.idxs.size()); - const llama_seq_id seq_id = cells.seq_get(head_cur + i); - const llama_pos pos = cells.pos_get(head_cur + i); + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto idx = sinfo.idxs.at(i); + + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - cells.rm(head_cur + i); + cells.rm(idx); } - cells.pos_set(head_cur + i, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(head_cur + i, ubatch.seq_id[i][s]); + cells.seq_add(idx, ubatch.seq_id[i][s]); } } @@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch } // move the head at the end of the slot - head = head_cur + ubatch.n_tokens; + head = sinfo.idxs.back() + 1; } bool llama_kv_cache_unified::get_can_shift() const { @@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint 0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; + const int64_t n_embd_k_gqa = k->ne[0]; const int64_t n_tokens = k_cur->ne[2]; + k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + + if (k_idxs && supports_set_rows) { + return ggml_set_rows(ctx, k, k_cur, k_idxs); + } + + // TODO: fallback to old ggml_cpy() method for backwards compatibility + // will be removed when ggml_set_rows() is adopted by all backends + ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + n_tokens*n_embd_k_gqa, + ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); return ggml_cpy(ctx, k_cur, k_view); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const int64_t n_embd_v_gqa = v->ne[0]; const int64_t n_tokens = v_cur->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + + if (v_idxs && supports_set_rows) { + if (!v_trans) { + return ggml_set_rows(ctx, v, v_cur, v_idxs); + } + + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + + // note: the V cache is transposed when not using flash attention + v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + + // note: we can be more explicit here at the cost of extra cont + // however, above we take advantage that a row of single element is always continuous regardless of the row stride + //v_cur = ggml_transpose(ctx, v_cur); + //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + + // we broadcast the KV indices n_embd_v_gqa times + // v [1, n_kv, n_embd_v_gqa] + // v_cur [1, n_tokens, n_embd_v_gqa] + // v_idxs [n_tokens, 1, 1] + return ggml_set_rows(ctx, v_view, v_cur, v_idxs); + } + + // TODO: fallback to old ggml_cpy() method for backwards compatibility + // will be removed when ggml_set_rows() is adopted by all backends ggml_tensor * v_view = nullptr; if (!v_trans) { v_view = ggml_view_1d(ctx, v, - n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + n_tokens*n_embd_v_gqa, + ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head()); } else { - // note: the V cache is transposed when not using flash attention - v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), - (v->ne[1])*ggml_element_size(v), - (head_cur)*ggml_element_size(v)); - v_cur = ggml_transpose(ctx, v_cur); + + v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa, + (v->ne[1] )*ggml_element_size(v), + (sinfo.head())*ggml_element_size(v)); } return ggml_cpy(ctx, v_cur, v_view); } +ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + + ggml_set_input(k_idxs); + + return k_idxs; +} + +ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + + ggml_set_input(v_idxs); + + return v_idxs; +} + +void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { + if (!supports_set_rows) { + return; + } + + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + int64_t * data = (int64_t *) dst->data; + + for (int64_t i = 0; i < n_tokens; ++i) { + data[i] = sinfo.idxs.at(i); + } +} + +void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { + if (!supports_set_rows) { + return; + } + + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + int64_t * data = (int64_t *) dst->data; + + for (int64_t i = 0; i < n_tokens; ++i) { + data[i] = sinfo.idxs.at(i); + } +} + void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; @@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell ubatch.seq_id[i] = &dest_seq_id; } - const auto head_cur = find_slot(ubatch); - if (head_cur < 0) { + const auto sinfo = find_slot(ubatch, true); + if (sinfo.empty()) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - apply_ubatch(head_cur, ubatch); + apply_ubatch(sinfo, ubatch); + + const auto head_cur = sinfo.head(); // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; @@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); - head = 0; + + // create a dummy slot info - the actual data is irrelevant. we just need to build the graph + sinfos.resize(1); + sinfos[0].idxs.resize(1); + sinfos[0].idxs[0] = 0; } llama_kv_cache_unified_context::llama_kv_cache_unified_context( @@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, - llama_kv_cache_unified::ubatch_heads heads, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) { + llama_kv_cache_unified::slot_info_vec_t sinfos, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; bool llama_kv_cache_unified_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - if (++i_next >= ubatches.size()) { + if (++i_cur >= ubatches.size()) { return false; } @@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(heads[i_next], ubatches[i_next]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); n_kv = kv->get_n_kv(); - head = heads[i_next]; return true; } @@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const { const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return ubatches[i_next]; + return ubatches[i_cur]; } uint32_t llama_kv_cache_unified_context::get_n_kv() const { @@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t return kv->get_v(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { - return kv->cpy_k(ctx, k_cur, il, head); +ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { + return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { - return kv->cpy_v(ctx, v_cur, il, head); +ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { + return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]); +} + +ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + return kv->build_input_k_idxs(ctx, ubatch); +} + +ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + return kv->build_input_v_idxs(ctx, ubatch); } void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } +void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]); +} + +void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]); +} + void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 4c53f1273..b8b0356e8 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -24,8 +24,6 @@ public: // this callback is used to filter out layers that should not be included in the cache using layer_filter_cb = std::function; - using ubatch_heads = std::vector; - struct defrag_info { bool empty() const { return ids.empty(); @@ -37,6 +35,32 @@ public: std::vector ids; }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the + // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] + struct slot_info { + // data for ggml_set_rows + using idx_vec_t = std::vector; + + idx_vec_t idxs; + + uint32_t head() const { + return idxs.at(0); + } + + bool empty() const { + return idxs.empty(); + } + + void clear() { + idxs.clear(); + } + + // TODO: implement + //std::vector seq_idxs; + }; + + using slot_info_vec_t = std::vector; + llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -102,30 +126,37 @@ public: ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; // // preparation API // - // find places for the provided ubatches in the cache, returns the head locations + // find places for the provided ubatches in the cache, returns the slot infos // return empty vector on failure - ubatch_heads prepare(const std::vector & ubatches); + slot_info_vec_t prepare(const std::vector & ubatches); bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); - // return the cell position where we can insert the ubatch - // return -1 on failure to find a contiguous slot of kv cells - int32_t find_slot(const llama_ubatch & ubatch) const; + // find a slot of kv cells that can hold the ubatch + // if cont == true, then the slot must be continuous + // return empty slot_info on failure + slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; - // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) - void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); // - // set_input API + // input API // + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -157,8 +188,13 @@ private: // SWA const uint32_t n_swa = 0; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + int supports_set_rows = false; + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; @@ -211,8 +247,8 @@ private: class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using ubatch_heads = llama_kv_cache_unified::ubatch_heads; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -231,7 +267,7 @@ public: // used to create a batch procesing context from a batch llama_kv_cache_unified_context( llama_kv_cache_unified * kv, - ubatch_heads heads, + slot_info_vec_t sinfos, std::vector ubatches); virtual ~llama_kv_cache_unified_context(); @@ -257,11 +293,16 @@ public: ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; - void set_input_k_shift(ggml_tensor * dst) const; + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + + void set_input_k_shift (ggml_tensor * dst) const; void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -283,10 +324,10 @@ private: // batch processing context // - // the index of the next ubatch to process - size_t i_next = 0; + // the index of the cur ubatch to process + size_t i_cur = 0; - ubatch_heads heads; + slot_info_vec_t sinfos; std::vector ubatches; @@ -297,7 +338,4 @@ private: // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; - - // the beginning of the current slot in which the ubatch will be inserted - int32_t head; }; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index c95d63594..0d0dd316f 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -105,10 +105,30 @@ public: res.resize(n); for (uint32_t j = 0; j < n; ++j) { - res.pos[j] = pos[i + j]; - res.seq[j] = seq[i + j]; + const auto idx = i + j; - assert(shift[i + j] == 0); + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); + } + + return res; + } + + // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) + llama_kv_cells_unified cp(const std::vector & idxs) const { + llama_kv_cells_unified res; + + res.resize(idxs.size()); + + for (uint32_t j = 0; j < idxs.size(); ++j) { + const auto idx = idxs[j]; + + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); } return res; @@ -119,26 +139,58 @@ public: assert(i + other.pos.size() <= pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { - if (pos[i + j] == -1 && other.pos[j] != -1) { + const auto idx = i + j; + + if (pos[idx] == -1 && other.pos[j] != -1) { used.insert(i + j); } - if (pos[i + j] != -1 && other.pos[j] == -1) { + if (pos[idx] != -1 && other.pos[j] == -1) { used.erase(i + j); } - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_rm(i + j); } - pos[i + j] = other.pos[j]; - seq[i + j] = other.seq[j]; + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_add(i + j); } - assert(shift[i + j] == 0); + assert(shift[idx] == 0); + } + } + + // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) + void set(const std::vector & idxs, const llama_kv_cells_unified & other) { + assert(idxs.size() == other.pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + const auto idx = idxs[j]; + + if (pos[idx] == -1 && other.pos[j] != -1) { + used.insert(idx); + } + + if (pos[idx] != -1 && other.pos[j] == -1) { + used.erase(idx); + } + + if (pos[idx] != -1) { + seq_pos_rm(idx); + } + + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; + + if (pos[idx] != -1) { + seq_pos_add(idx); + } + + assert(shift[idx] == 0); } } diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 67cbf9554..6cd10db06 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) { @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { // TODO: will the recurrent cache be in an undefined context at this point? @@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, - std::vector heads_attn, + slot_info_vec_t sinfos_attn, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), + ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index f0c2420e9..4ac318175 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -92,6 +92,8 @@ private: class llama_memory_hybrid_context : public llama_memory_context_i { public: + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + // init failure explicit llama_memory_hybrid_context(llama_memory_status status); @@ -107,7 +109,7 @@ public: // init success llama_memory_hybrid_context( llama_memory_hybrid * mem, - std::vector heads_attn, + slot_info_vec_t sinfos_attn, std::vector ubatches); ~llama_memory_hybrid_context() = default; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4b3711207..b502befdc 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } - if (ubatch.n_tokens == 0) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 366488813..8e3b755ba 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -213,23 +213,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV6: { @@ -1086,6 +1090,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3216,6 +3252,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } @@ -4727,10 +4811,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + } + + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); if (!classifier_labels.empty()) { @@ -9765,9 +9853,7 @@ struct llm_build_starcoder2 : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context { - const llama_model & model; - - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -9785,7 +9871,11 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + } else { + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -9819,11 +9909,11 @@ struct llm_build_mamba : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - // TODO: split ggml_tensor * build_mamba_layer( llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, + const llama_model & model, const llama_ubatch & ubatch, int il) const { const auto * mctx_cur = static_cast(mctx); @@ -9834,6 +9924,8 @@ struct llm_build_mamba : public llm_graph_context { const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9849,15 +9941,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // (ab)using the KV cache to store the states - ggml_tensor * conv = build_rs( - inp, gf, conv_states_all, - hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_rs( - inp, gf, ssm_states_all, - hparams.n_embd_s(), n_seqs); - ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9906,8 +9991,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9920,23 +10005,36 @@ struct llm_build_mamba : public llm_graph_context { dt = build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + cur = x; + x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9945,7 +10043,136 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - //cb(cur, "mamba_out", il); + // cb(cur, "mamba_out", il); + + return cur; + } + + ggml_tensor * build_mamba2_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(mctx); + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // cb(cur, "mamba_out", il); return cur; } @@ -14768,6 +14995,7 @@ llm_graph_result_ptr llama_model::build_graph( llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: { llm = std::make_unique(*this, params, gf); } break; @@ -15028,6 +15256,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: diff --git a/src/llama-model.h b/src/llama-model.h index a958c5997..979fff620 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -172,6 +172,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm = nullptr; struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; + struct ggml_tensor * ssm_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index afe83e128..aaf58edf4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1430,8 +1430,7 @@ struct clip_graph { ggml_tensor * x = embeddings; embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); - embeddings = ggml_silu_inplace(ctx0, embeddings); - embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_swiglu_split(ctx0, embeddings, x); embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } // arrangement of BOI/EOI token embeddings @@ -1527,15 +1526,8 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); // swiglu - { - int64_t split_point = cur->ne[0] / 2; - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half - x1 = ggml_silu(ctx0, x1); - cur = ggml_mul(ctx0, x0, x1); - } + // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half + cur = ggml_swiglu_swapped(ctx0, cur); // mid-norm cur = ggml_rms_norm(ctx0, cur, 1e-6); @@ -1794,35 +1786,42 @@ private: cur = tmp; } + // we only support parallel ffn for now switch (type_op) { case FFN_SILU: - { + if (gate) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); } break; case FFN_GELU: - { + if (gate) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); } break; case FFN_GELU_ERF: - { + if (gate) { + cur = ggml_geglu_erf_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_erf", il); + } else { cur = ggml_gelu_erf(ctx0, cur); - cb(cur, "ggml_gelu_erf", il); + cb(cur, "ffn_gelu_erf", il); } break; case FFN_GELU_QUICK: - { + if (gate) { + cur = ggml_geglu_quick_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_quick", il); + } else { cur = ggml_gelu_quick(ctx0, cur); - cb(cur, "ffn_relu", il); + cb(cur, "ffn_gelu_quick", il); } break; } - // we only support parallel ffn for now - if (gate) { - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "ffn_gate_par", il); - } - if (down) { cur = ggml_mul_mat(ctx0, down, cur); }