mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/ISSUE_TEMPLATE/010-bug-compilation.yml
#	.github/ISSUE_TEMPLATE/011-bug-results.yml
#	.github/labeler.yml
#	.github/workflows/build.yml
#	.github/workflows/release.yml
#	.gitmodules
#	CMakeLists.txt
#	ggml/CMakeLists.txt
#	ggml/src/CMakeLists.txt
#	ggml/src/ggml-cann/aclnn_ops.cpp
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
#	ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
#	ggml/src/ggml-opencl/kernels/softmax_f16.cl
#	ggml/src/ggml-opencl/kernels/softmax_f32.cl
#	ggml/src/ggml-sycl/element_wise.cpp
#	ggml/src/ggml-sycl/element_wise.hpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	scripts/sync-ggml-am.sh
#	scripts/sync-ggml.last
#	scripts/sync-ggml.sh
#	tests/test-backend-ops.cpp
#	tests/test-c.c

This commit is contained in: commit 57ce374240

64 changed files with 2944 additions and 979 deletions
@@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model):
     ]
 
     def set_vocab(self):
-        with open(self.dir_model / "chat_template.jinja") as f:
-            # quick hack to make sure chat template is added
-            self.gguf_writer.add_chat_template(f.read())
         super().set_vocab()
 
     def set_gguf_parameters(self):
@@ -4781,6 +4778,14 @@ class ARwkv7Model(Rwkv7Model):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4855,6 +4860,100 @@ class MambaModel(TextModel):
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6615,12 +6714,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
     # maybe we should fallback to text model's arch in that case, since not many models have both
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
 
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
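For reference, a minimal sketch of the ceiling-division padding used by `Mamba2Model.set_vocab` above. The Python expression `-(vocab_size // -pad_vocab) * pad_vocab` is ceiling division only because `//` floors; a C analogue has to spell the rounding out, since C integer division truncates. The helper name and sample values below are illustrative, not from the diff.

```c
#include <stdio.h>
#include <stdint.h>

// hypothetical helper, not part of the repository
static int64_t pad_to_multiple(int64_t vocab_size, int64_t pad) {
    // ceiling division, then scale back up to the next multiple of `pad`
    return ((vocab_size + pad - 1) / pad) * pad;
}

int main(void) {
    // e.g. a 50277-token vocab padded to a multiple of 16 becomes 50288
    printf("%lld\n", (long long) pad_to_multiple(50277, 16));
    return 0;
}
```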
@@ -1,50 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#define GGML_KOMPUTE_MAX_DEVICES 16
-
-struct ggml_vk_device {
-    int index;
-    int type; // same as VkPhysicalDeviceType
-    size_t heapSize;
-    const char * name;
-    const char * vendor;
-    int subgroupSize;
-    uint64_t bufferAlignment;
-    uint64_t maxAlloc;
-};
-
-struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
-bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
-bool ggml_vk_has_vulkan(void);
-bool ggml_vk_has_device(void);
-struct ggml_vk_device ggml_vk_current_device(void);
-
-//
-// backend API
-//
-
-// forward declaration
-typedef struct ggml_backend * ggml_backend_t;
-
-GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
-
-GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
-
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
-
-GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
-
-#ifdef __cplusplus
-}
-#endif
@@ -563,6 +563,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,
 
         GGML_GLU_OP_COUNT,
     };
@@ -659,6 +661,9 @@ extern "C" {
 
     // misc
 
+    GGML_API const char * ggml_version(void);
+    GGML_API const char * ggml_commit(void);
+
     GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
     GGML_API int64_t ggml_time_ms(void);
     GGML_API int64_t ggml_time_us(void);
@@ -1157,6 +1162,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1180,6 +1201,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
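A minimal usage sketch for the new GEGLU-erf API declared above, assuming the standard ggml CPU-backend graph helpers (`ggml_graph_compute_with_ctx` from `ggml-cpu.h`); tensor sizes and the function name are arbitrary and the data is left uninitialized, so this only illustrates the call pattern, not the repository's own usage.

```c
#include "ggml.h"
#include "ggml-cpu.h"

static void geglu_erf_demo(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // fused form: the last dimension holds both halves of the gate, so ne0 must be even
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * out = ggml_geglu_erf(ctx, a);   // result has ne0 = 32

    // (in real use, fill a->data before computing)
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
}
```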
@@ -1523,8 +1554,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // a    [ne0, ne01, ne02, ne03]
+    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
+    //
+    // broadcast:
+    //   ne02 % ne12 == 0
+    //   ne03 % ne13 == 0
+    //
     // fused soft_max(a*scale + mask*(ALiBi slope))
-    // mask is optional
     // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
@@ -1987,11 +2024,17 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    1]
-    // k:    [n_embd_k, n_kv,        n_head_kv, 1]
-    // v:    [n_embd_v, n_kv,        n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   1] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
+    //
+    // broadcast:
+    //   n_head % n_head_kv == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
+    //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -2030,7 +2073,8 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C);
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:
@@ -61,10 +61,6 @@
 #include "ggml-cann.h"
 #endif
 
-#ifdef GGML_USE_KOMPUTE
-#include "ggml-kompute.h"
-#endif
-
 // disable C++17 deprecation warning for std::codecvt_utf8
 #if defined(__clang__)
 #    pragma clang diagnostic push
@@ -189,9 +185,6 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
-#ifdef GGML_USE_KOMPUTE
-        register_backend(ggml_backend_kompute_reg());
-#endif
 #ifdef GGML_USE_CPU
         register_backend(ggml_backend_cpu_reg());
 #endif
@@ -576,7 +569,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("cann", silent, dir_path);
     ggml_backend_load_best("cuda", silent, dir_path);
     ggml_backend_load_best("hip", silent, dir_path);
-    ggml_backend_load_best("kompute", silent, dir_path);
     ggml_backend_load_best("metal", silent, dir_path);
     ggml_backend_load_best("rpc", silent, dir_path);
     ggml_backend_load_best("sycl", silent, dir_path);
@@ -2186,6 +2186,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -5232,14 +5518,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 
     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -5249,68 +5538,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
-#endif
-
-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
-
-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
-
-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
-
-#ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
-#endif
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
+                }
+
+#ifndef NDEBUG
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
+#endif
+
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);
+
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
+
+#ifndef NDEBUG
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
+#endif
+            }
+        }
     }
 }
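Restating what the reworked loop above computes per row, purely as a formula derived from the code (no behaviour beyond the diff is implied):

```latex
% per row i01 of head h = i02, batch index i03
\[
  \mathrm{dst}_i = \operatorname{softmax}_i\bigl(\mathrm{scale}\cdot a_i + \mathrm{slope}(h)\cdot m_i\bigr),
  \qquad
  \mathrm{slope}(h) =
  \begin{cases}
    1 & \text{if } \text{max\_bias} = 0\\
    m_0^{\,h+1} & \text{if } h < n_{\text{head\_log2}}\\
    m_1^{\,2(h-n_{\text{head\_log2}})+1} & \text{otherwise}
  \end{cases}
\]
% with m_0 = 2^{-max_bias / n_head_log2}, m_1 = 2^{-max_bias / (2 n_head_log2)},
% and the mask row m broadcast via i12 = i02 mod ne12, i13 = i03 mod ne13.
```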
@@ -7766,7 +8053,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     ggml_type    const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t    const kq_vec_dot   = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float   = ggml_get_type_traits(v->type)->to_float;
@@ -7798,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
@@ -8336,120 +8623,210 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
     const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nr  = src0->ne[1]; // dim
+    const int64_t nh  = src1->ne[1]; // n_head
+    const int64_t ng  = src4->ne[1];
+    const int64_t nt  = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns  = src1->ne[3]; // number of sequences in the batch
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
 
-#ifdef __ARM_FEATURE_SVE
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-                for (int64_t k = 0; k < nc; k += svcntw()) {
-                    svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
-                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
-                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
-                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
-                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                    t1 = exp_ps_sve(svptrue_b32(), t1);
-                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
-                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
-                }
-                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
-            }
-        }
-    }
-#else
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (      float *) ((      char *)  dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (      float *) ((      char *)  dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
-#endif
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *)  dst->data +     i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+                  float * y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+                    }
+                }
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
+                    }
+                }
+            }
+
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
+        }
+    }
 }
 
 void ggml_compute_forward_ssm_scan(
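Restating the recurrence that the scan above implements (directly from the code; the Mamba-2 branch hoists the scalar per-head decay out of the state loop):

```latex
\[
  s_t = s_{t-1}\, e^{\Delta_t A_h} + \bigl(\Delta_t x_t\bigr) B_t,
  \qquad
  y_t = \langle s_t, C_t \rangle,
  \qquad
  \Delta_t = \operatorname{softplus}(\mathrm{dt}_t) = \log\!\bigl(1 + e^{\mathrm{dt}_t}\bigr),
\]
% with s_t in R^{d_state}; the softplus is bypassed for dt_t > 20, as in the code.
% In the Mamba-1 branch A is per-state, so the decay e^{Delta_t A} is element-wise
% rather than the single scalar dA per head.
```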
@@ -8688,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F32xt_LOAD(...)                 GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_STORE_IMPL(pg,a,b)        svst1_f32(pg, a, b)
 #define GGML_F32xt_STORE(...)                GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define GGML_F32xt_FMA_IMPL(pg, a, b, c)     svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c)     svmad_f32_m(pg, b, c, a)
 #define GGML_F32xt_FMA(...)                  GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_ADD_IMPL(pg, a, b)        svadd_f32_m(pg, a, b)
 #define GGML_F32xt_ADD(...)                  GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
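The macro change above switches the FMA convention to "accumulator first": since SVE's `svmad_f32_m(pg, op1, op2, op3)` computes `op1*op2 + op3`, mapping `(a, b, c)` to `svmad_f32_m(pg, b, c, a)` makes `GGML_F32_VEC_FMA(a, b, c)` read as `a + b*c`, which is what the updated call sites in the next hunks rely on. A scalar sketch of the convention (the helper is illustrative, not from the repository):

```c
#include <stdio.h>

// scalar analogue of the new convention: accumulator first, then the two multiplicands
static float f32_fma(float a, float b, float c) {
    return a + b * c;
}

int main(void) {
    float sum = 0.0f;
    const float x[4] = {1, 2, 3, 4};
    const float y[4] = {10, 20, 30, 40};
    for (int i = 0; i < 4; ++i) {
        // mirrors sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1): a running dot product
        sum = f32_fma(sum, x[i], y[i]);
    }
    printf("%f\n", sum); // 1*10 + 2*20 + 3*30 + 4*40 = 300
    return 0;
}
```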
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
     for (int i = 0; i < np; i += ggml_f32_step) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
-        sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

         ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
         ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
-        sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
+        sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

         ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
         ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
-        sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
+        sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

         ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
         ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
-        sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
+        sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

         ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
         ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
-        sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
+        sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

         ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
         ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
-        sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
+        sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

         ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
         ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
-        sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
+        sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

         ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
         ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
-        sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
+        sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
     }

     // leftovers
     // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
     for (int i = np; i < np2; i += ggml_f32_epr) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
-        sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
     }

     // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
     if (np2 < n) {
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const

         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
-        ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

         GGML_F32_VEC_STORE(y + i, ay1);

         ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
         ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
-        ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
+        ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);

         GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

         ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
         ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
-        ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
+        ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);

         GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

         ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
         ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
-        ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
+        ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);

         GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

         ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
         ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
-        ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
+        ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);

         GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

         ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
         ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
-        ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
+        ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);

         GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

         ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
         ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
-        ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
+        ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);

         GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

         ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
         ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
-        ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
+        ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);

         GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
     }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
     for (int i = np; i < np2; i += ggml_f32_epr) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
-        ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

         GGML_F32_VEC_STORE(y + i, ay1);
     }
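The two hunks above only reorder the GGML_F32_VEC_FMA operands so the accumulator comes first; the arithmetic is unchanged. As a rough scalar sketch of what the unrolled SVE loops compute (the helper names below are illustrative, not part of ggml):

    #include <cmath>
    #include <cstddef>

    // dot product: sum += x[i]*y[i], one fused multiply-add per element
    static float dot_f32_ref(size_t n, const float * x, const float * y) {
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            sum = fmaf(x[i], y[i], sum);
        }
        return sum;
    }

    // mad / axpy: y[i] += x[i]*v, again fused
    static void mad_f32_ref(size_t n, float * y, const float * x, float v) {
        for (size_t i = 0; i < n; ++i) {
            y[i] = fmaf(x[i], v, y[i]);
        }
    }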
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
     }
 }

+inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
+    }
+}
+
+inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
+    }
+}
+#else
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
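The new GEGLU helpers above gate the second input with the exact, erf-based GELU: y[i] = 0.5·x[i]·(1 + erf(x[i]/√2))·g[i]. A standalone sketch of one element, assuming SQRT_2_INV is 1/√2 as in ggml (the helper name here is hypothetical):

    #include <cmath>
    #include <cstdio>

    // recompute one GEGLU-erf element the same way the new loop above does
    static float geglu_erf_ref(float x, float g) {
        const float SQRT_2_INV = 0.70710678f; // assumed 1/sqrt(2), as used above
        return 0.5f * x * (1.0f + erff(x * SQRT_2_INV)) * g;
    }

    int main() {
        printf("%f\n", geglu_erf_ref(1.5f, 2.0f)); // gate applied to the value row g
        return 0;
    }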
@@ -179,6 +179,20 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
+        const int id = ggml_cuda_get_device(); \
+        if (!shared_memory_limit_raised[id]) { \
+            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+            shared_memory_limit_raised[id] = true; \
+        } \
+    } while (0)
+#else
+#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
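The macro above centralizes the opt-in for dynamic shared memory beyond the default limit: it calls cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes) once per device and remembers that in a per-device static flag. A hedged usage sketch with a hypothetical kernel (assumes the ggml CUDA common header is included so the macro and CUDA_CHECK are visible):

    // illustrative kernel only: touches dynamic shared memory
    __global__ void my_kernel(float * dst) {
        extern __shared__ float buf[];
        buf[threadIdx.x] = 0.0f;
        dst[threadIdx.x] = buf[threadIdx.x];
    }

    static void launch_my_kernel(float * dst, size_t nbytes_shared, cudaStream_t stream) {
        // raise the per-kernel dynamic shared memory cap once per device, then launch;
        // nbytes_shared must be at least 256*sizeof(float) for this toy kernel
        CUDA_SET_SHARED_MEMORY_LIMIT(my_kernel, nbytes_shared);
        my_kernel<<<1, 256, nbytes_shared, stream>>>(dst);
    }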
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);

     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
         cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
     } else {
         cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

     if (nbytes_shared <= smpbo) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-        if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
-            shared_memory_limit_raised[id] = true;
-        }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
         cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
     } else {
         cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
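The extra ne32/nb32 parameters let the FlashAttention kernels broadcast the mask along its third dimension instead of assuming a single 2D mask per launch. A sketch of the pointer arithmetic the updated kernels use (the helper below is illustrative; nb31/nb32 are byte strides as in the diff):

    #include <cuda_fp16.h>
    #include <cstddef>

    // byte offset into the optional mask for a given head/channel and query row
    static __device__ const half * mask_ptr(const char * mask, int channel, int row,
                                            int ne32, int nb31, int nb32) {
        if (mask == nullptr) {
            return nullptr;
        }
        // dim 2 of the mask is broadcast across channels: channel % ne32
        return (const half *) (mask + (size_t) nb32*(channel % ne32) + (size_t) nb31*row);
    }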
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(

         const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
         const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(

         const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
         const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
     GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
     GGML_UNUSED(ne2); GGML_UNUSED(ne3);
@@ -6,7 +6,7 @@

 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -6,7 +6,7 @@

 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio);

-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
     Q += nb02* blockIdx.z + nb01*ic0;
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half2 * mask2 = (const half2 *) maskh;

     const int stride_Q = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
@@ -2319,6 +2319,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             case GGML_GLU_OP_SWIGLU:
                 ggml_cuda_op_swiglu(ctx, dst);
                 break;
+            case GGML_GLU_OP_GEGLU_ERF:
+                ggml_cuda_op_geglu_erf(ctx, dst);
+                break;
+            case GGML_GLU_OP_GEGLU_QUICK:
+                ggml_cuda_op_geglu_quick(ctx, dst);
+                break;
             default:
                 return false;
         }
@@ -3121,6 +3127,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
@@ -3326,12 +3334,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
-        case GGML_OP_SSM_SCAN:
-        case GGML_OP_SSM_CONV:
             return true;
+        case GGML_OP_SSM_SCAN: {
+            if (op->src[3]->ne[0] == 1) {
+                // Mamba2
+                // (kernel only supports d_state == 128 && d_head % 16 == 0)
+                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+            } else {
+                // Mamba
+                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+            }
+        }
+        case GGML_OP_SSM_CONV: {
+            // assumes d_inner % threads == 0
+            return op->src[0]->ne[1] % 128 == 0;
+        }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
+            return true;
         case GGML_OP_SOFT_MAX:
             return true;
         case GGML_OP_SOFT_MAX_BACK: {
@@ -3380,6 +3402,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            // TODO: support broadcast
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
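The new SSM_SCAN support check distinguishes the Mamba-2 path (scalar A per head, src[3]->ne[0] == 1) from the original Mamba path. Restated as a standalone predicate with the dimensions given names (my reading of the condition above, not a ggml API):

    #include <cstdint>

    struct ssm_scan_shapes {
        int64_t d_state;   // op->src[0]->ne[0]
        int64_t head_dim;  // op->src[0]->ne[1]
        int64_t n_head;    // op->src[0]->ne[2]
        int64_t n_group;   // op->src[4]->ne[1]
        bool    is_mamba2; // op->src[3]->ne[0] == 1
    };

    static bool ssm_scan_supported(const ssm_scan_shapes & s) {
        if (s.is_mamba2) {
            // Mamba-2 kernel: d_state fixed at 128, head dim split into chunks of 16
            return s.d_state == 128 && s.head_dim % 16 == 0;
        }
        // original Mamba kernel
        return s.d_state == 16 && s.head_dim == 1 && s.n_head % 128 == 0 && s.n_group == 1;
    }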
@@ -3017,14 +3017,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

     const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);

-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-    static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
-    if (!shared_memory_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
-        shared_memory_limit_raised[id] = true;
-    }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);

     const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
     const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>

 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }

+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;

     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;

     x += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst += int64_t(rowx)*ncols;

     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
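With the soft_max_params rewrite, the kernel no longer assumes the mask repeats only along rows: it derives (i11, i12, i13) from the block indices and wraps dims 2 and 3 of the mask when they are smaller than the input's. A sketch of just that index math (field names follow the struct above; the helper itself is illustrative):

    #include <cstdint>

    // byte offset into the mask for row i01 of slice (i02, i03);
    // nb1x are byte strides, ne12/ne13 the broadcast extents of the mask
    static int64_t mask_byte_offset(int64_t i01, int64_t i02, int64_t i03,
                                    int64_t ne12, int64_t ne13,
                                    int64_t nb11, int64_t nb12, int64_t nb13) {
        const int64_t i11 = i01;
        const int64_t i12 = i02 % ne12;
        const int64_t i13 = i03 % ne13;
        return i11*nb11 + i12*nb12 + i13*nb13;
    }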
@@ -38,7 +70,7 @@ static __global__ void soft_max_f32(
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);

     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +87,7 @@ static __global__ void soft_max_f32(
             break;
         }

-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
@@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32(
     }
 }

+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+                                    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    //default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

-    const uint32_t n_head = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    const int id       = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
-                break;
-        }
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
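launch_soft_max_kernels above replaces the hand-written switch with a C++17 fold over ||: each candidate Ns is tried in order and the first match launches its specialization and short-circuits the rest. A generic sketch of the same dispatch pattern (names here are illustrative, not the kernel launcher itself):

    #include <cstdio>
    #include <type_traits>

    template <int... Ns>
    static bool dispatch_ncols(int ncols) {
        auto try_one = [&](auto I) -> bool {
            constexpr int N = decltype(I)::value;
            if (ncols == N) {
                printf("would launch the specialization for %d columns\n", N);
                return true; // short-circuits the fold; later candidates are skipped
            }
            return false;
        };
        return (try_one(std::integral_constant<int, Ns>{}) || ...);
    }

    // usage: if (!dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(ncols)) { /* generic fallback */ }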
@@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

-    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];

+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
@@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads      = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols       = ne00;
+    params.nrows_x     = nrows_x;
+    params.nrows_y     = nrows_y;
+    params.ne00        = src0->ne[0];
+    params.ne01        = src0->ne[1];
+    params.ne02        = src0->ne[2];
+    params.ne03        = src0->ne[3];
+    params.nb11        = nb11;
+    params.nb12        = nb12;
+    params.nb13        = nb13;
+    params.ne12        = ne12;
+    params.ne13        = ne13;
+    params.scale       = scale;
+    params.max_bias    = max_bias;
+    params.m0          = m0;
+    params.m1          = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }
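The host code above moves the ALiBi slope setup out of soft_max_f32_cuda and packs it into soft_max_params. A small sketch of just that setup, mirroring the formulas in the diff (the struct and helper here are illustrative):

    #include <cmath>
    #include <cstdint>

    struct alibi_params { uint32_t n_head_log2; float m0, m1; };

    // n_head_log2 is the largest power of two not exceeding n_head;
    // m0/m1 are the geometric bases the kernel raises per head
    static alibi_params alibi_setup(uint32_t n_head, float max_bias) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        return {
            n_head_log2,
            powf(2.0f, -(max_bias       ) / n_head_log2),
            powf(2.0f, -(max_bias / 2.0f) / n_head_log2),
        };
    }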
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
-                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
-                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                 float * __restrict__ dst, const int64_t L) {
-    GGML_UNUSED(src1_nb0);
-    GGML_UNUSED(src2_nb0);
+                 const int32_t * __restrict__ src6, float * __restrict__ dst,
+                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+                 const int64_t s_off, const int64_t d_inner, const int64_t L) {

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x; // split along B
-    const int bidy = blockIdx.y; // split along D
+    const int bidx = blockIdx.x; // split along B (sequences)
+    const int bidy = blockIdx.y; // split along D (d_inner)
     const int tid = threadIdx.x;
     const int wid = tid / 32;
     const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
     float * smem_A = smem;
     float * smem_s0 = smem_A + splitD * stride_sA;

-    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
-    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
     const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
-    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
-    float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+    float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+    float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_x = src1_nb1 / sizeof(float);
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_x = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
     const int stride_A = src3_nb1 / sizeof(float);
-    const int stride_B = src4_nb1 / sizeof(float);
-    const int stride_C = src5_nb1 / sizeof(float);
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
     const int stride_s = stride_s0;
-    const int stride_y = stride_x;
+    const int stride_y = d_inner;

     // can N not be 16? for example 32?
     if (N == 16) {
@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assumes as many threads as d_state
|
||||||
|
template <int splitH, int d_state>
|
||||||
|
__global__ void __launch_bounds__(d_state, 1)
|
||||||
|
ssm_scan_f32_group(
|
||||||
|
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
|
||||||
|
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
|
||||||
|
const int32_t * __restrict__ src6, float * __restrict__ dst,
|
||||||
|
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
|
||||||
|
const int src2_nb1, const int src2_nb2, const int src3_nb1,
|
||||||
|
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
|
||||||
|
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
|
||||||
|
|
||||||
|
const int head_idx = (blockIdx.x * splitH) / d_head;
|
||||||
|
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
|
||||||
|
const int seq_idx = blockIdx.y;
|
||||||
|
+    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+    float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+    // strides across n_seq_tokens
+    const int stride_x = src1_nb2 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
+    const int stride_y = n_head * d_head;
+
+    float state[splitH];
+    // for the parallel accumulation
+    __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        state[j] = s0_block[j * d_state + threadIdx.x];
+    }
+
+    for (int64_t i = 0; i < n_tok; i++) {
+        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+        // TODO: only calculate B and C once per head group
+        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+        float dt_soft_plus = dt_block[i * stride_dt];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
+        }
+        const float dA = expf(dt_soft_plus * A_block[0]);
+        const float B = B_block[i * stride_B + threadIdx.x];
+        const float C = C_block[i * stride_C + threadIdx.x];
+
+        // across d_head
+#pragma unroll
+        for (int j = 0; j < splitH; j++) {
+            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+            state[j] = (state[j] * dA) + (B * x_dt);
+
+            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        }
+
+        __syncthreads();
+
+        // parallel accumulation for stateC
+        // TODO: simplify
+        {
+            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+            // reduce until w matches the warp size
+            // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+                // (assuming there are d_state threads)
+#pragma unroll
+                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+                    // TODO: check for bank conflicts
+                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+                    stateC[k] += stateC[k + (w >> 1)];
+
+                }
+                __syncthreads();
+            }
+
+            static_assert(splitH >= d_state / WARP_SIZE);
+
+#pragma unroll
+            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+                y = warp_reduce_sum(y);
+
+                // store the above accumulations
+                if (threadIdx.x % WARP_SIZE == 0) {
+                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+                    y_block[i * stride_y + k] = y;
+                }
+            }
+        }
+    }
+
+    // write back the state
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        s_block[j * d_state + threadIdx.x] = state[j];
+    }
+}
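For orientation, a minimal serial C++ sketch of the value the shared-memory reduction above converges to: each row j of the splitH x d_state accumulator collapses to the sum it holds (stateC[j][i] = state[j][i] * C[i]), which is the per-element output y. The function name is illustrative only and not taken from the source.

#include <cstddef>

// Serial reference for the tree + warp reduction: y[j] = sum_i stateC[j*d_state + i].
static void reduce_stateC_reference(const float * stateC, float * y, size_t splitH, size_t d_state) {
    for (size_t j = 0; j < splitH; ++j) {
        float acc = 0.0f;
        for (size_t i = 0; i < d_state; ++i) {
            acc += stateC[j * d_state + i]; // stateC[j][i] already carries the C[i] factor
        }
        y[j] = acc;
    }
}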
 static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
-                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
-                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
-                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                              float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+                              const float * src4, const float * src5, const int32_t * src6, float * dst,
+                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
     const int threads = 128;
-    // todo: consider D cannot be divided,does this situation exist?
-    GGML_ASSERT(D % threads == 0);
-    const dim3 blocks(B, (D + threads - 1) / threads, 1);
-    const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
-    if (N == 16) {
-        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
-            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+    if (src3_nb1 == sizeof(float)) {
+        // Mamba-2
+        if (d_state == 128) {
+            GGML_ASSERT(d_state % threads == 0);
+            // NOTE: can be any power of two between 4 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=128.");
+        }
     } else {
-        GGML_ABORT("doesn't support N!=16.");
+        // Mamba-1
+        GGML_ASSERT(n_head % threads == 0);
+        GGML_ASSERT(head_dim == 1);
+        GGML_ASSERT(n_group == 1);
+        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+        if (d_state == 16) {
+            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=16.");
+        }
     }
 }
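For reference, a minimal CPU sketch (plain C++) of the per-token recurrence that the Mamba-2 path dispatched above computes for one (sequence, head) pair, where A holds a single value per head: dt goes through a clamped softplus, the state decays by exp(dt*A) and accumulates B*x*dt, and the output is the dot product of the state with C. Names below are illustrative and not taken from the source.

#include <cmath>
#include <cstddef>
#include <vector>

static float ssm_scan_step_reference(std::vector<float> & state,   // [d_state]
                                     const std::vector<float> & B, // [d_state]
                                     const std::vector<float> & C, // [d_state]
                                     float A, float dt, float x) {
    const float dt_sp = dt <= 20.0f ? std::log1p(std::exp(dt)) : dt; // softplus with overflow guard
    const float dA    = std::exp(dt_sp * A);
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (size_t i = 0; i < state.size(); ++i) {
        state[i] = state[i] * dA + B[i] * x_dt;
        y += state[i] * C[i];
    }
    return y;
}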
@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // ids
-    // const int64_t d_state = src0->ne[0];
-    // const int64_t d_inner = src0->ne[1];
-    // const int64_t l = src1->ne[1];
-    // const int64_t b = src0->ne[2];

     const int64_t nc = src0->ne[0]; // d_state
-    const int64_t nr = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nr = src0->ne[1]; // head_dim or 1
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1]; // n_group
+    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+    const int64_t n_s = src1->ne[3]; // number of sequences in the batch

-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
+
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float * src3_d = (const float *) src3->data;
     const float * src4_d = (const float *) src4->data;
     const float * src5_d = (const float *) src5->data;
+    const int32_t * src6_d = (const int32_t *) src6->data;
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src6->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
-                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
-                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
 }
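A small sketch of the destination layout the wrapper above relies on: dst first stores the outputs y (one float per element of src1), and the updated per-sequence states follow at byte offset s_off = ggml_nelements(src1) * sizeof(float). The struct and helper names below are hypothetical, for illustration only.

#include <cstddef>
#include <cstdint>

struct ssm_scan_dst_view {
    float * y;      // {head_dim, n_head, n_tok, n_seq}
    float * states; // {d_state, head_dim, n_head, n_seq}, starts at byte offset s_off
};

static ssm_scan_dst_view split_ssm_scan_dst(float * dst_d, int64_t n_src1_elements) {
    const size_t s_off = (size_t) n_src1_elements * sizeof(float);
    return { dst_d, (float *) ((char *) dst_d + s_off) };
}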
@@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
 }

+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
+}
+
 /* silu_back */

 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
@@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -229,7 +229,11 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
@@ -461,9 +465,21 @@ typedef struct {
 } ggml_metal_kargs_sum_rows;

 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    scale;
     float    max_bias;
     float    m0;
@@ -499,26 +515,25 @@ typedef struct {
 typedef struct {
     int64_t  d_state;
     int64_t  d_inner;
+    int64_t  n_head;
+    int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t nb10;
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t nb50;
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;

 typedef struct {
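The reworked soft_max arguments above drop the packed-row assumption and carry explicit byte strides for the source, the mask, and the destination, so each backend forms row pointers the same way. A minimal sketch with a hypothetical helper name:

#include <cstdint>

// Compute the address of row (i01, i02, i03) of a float tensor given its byte strides.
static const float * soft_max_src_row(const char * src0,
                                      int32_t i01, int32_t i02, int32_t i03,
                                      uint64_t nb01, uint64_t nb02, uint64_t nb03) {
    return (const float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
}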
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -529,6 +530,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
+    GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1196,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1508,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1691,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         default:
             return false;
@@ -1725,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
@@ -2454,6 +2462,12 @@ static bool ggml_metal_encode_node(
                 case GGML_GLU_OP_SWIGLU:
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
+                    break;
                 default:
                     GGML_ABORT("fatal error");
             }
@@ -2644,10 +2658,7 @@ static bool ggml_metal_encode_node(
             memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale));
             memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));

-            const int64_t nrows_x = ggml_nrows(src0);
-            const int64_t nrows_y = src0->ne[1];
-
-            const uint32_t n_head      = nrows_x/nrows_y;
+            const uint32_t n_head      = src0->ne[2];
             const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

             const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
@@ -2707,6 +2718,18 @@ static bool ggml_metal_encode_node(
                 /*.ne00     =*/ ne00,
                 /*.ne01     =*/ ne01,
                 /*.ne02     =*/ ne02,
+                /*.nb01     =*/ nb01,
+                /*.nb02     =*/ nb02,
+                /*.nb03     =*/ nb03,
+                /*.ne11     =*/ ne11,
+                /*.ne12     =*/ ne12,
+                /*.ne13     =*/ ne13,
+                /*.nb11     =*/ nb11,
+                /*.nb12     =*/ nb12,
+                /*.nb13     =*/ nb13,
+                /*.nb1      =*/ nb1,
+                /*.nb2      =*/ nb2,
+                /*.nb3      =*/ nb3,
                 /*.scale    =*/ scale,
                 /*.max_bias =*/ max_bias,
                 /*.m0       =*/ m0,
@@ -2726,7 +2749,7 @@ static bool ggml_metal_encode_node(

             [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-            [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         } break;
     case GGML_OP_DIAG_MASK_INF:
         {
@@ -2800,71 +2823,91 @@ static bool ggml_metal_encode_node(
             struct ggml_tensor * src3 = node->src[3];
             struct ggml_tensor * src4 = node->src[4];
             struct ggml_tensor * src5 = node->src[5];
+            struct ggml_tensor * src6 = node->src[6];

             GGML_ASSERT(src3);
             GGML_ASSERT(src4);
             GGML_ASSERT(src5);
+            GGML_ASSERT(src6);

             size_t offs_src3 = 0;
             size_t offs_src4 = 0;
             size_t offs_src5 = 0;
+            size_t offs_src6 = 0;

             id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
             id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
             id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+            id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;

-            const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+            const int64_t ne30 = src3->ne[0];
             const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);

-            const uint64_t nb30 = src3->nb[0];
+            const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
             const uint64_t nb31 = src3->nb[1];

             const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-            const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+            const int64_t ne41 = src4->ne[1];
             const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+            const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);

-            const uint64_t nb40 = src4->nb[0];
+            const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
             const uint64_t nb41 = src4->nb[1];
             const uint64_t nb42 = src4->nb[2];
+            const uint64_t nb43 = src4->nb[3];

             const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
             const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
             const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+            const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);

-            const uint64_t nb50 = src5->nb[0];
+            const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
             const uint64_t nb51 = src5->nb[1];
             const uint64_t nb52 = src5->nb[2];
+            const uint64_t nb53 = src5->nb[3];
+
+            const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+            const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);

             const int64_t d_state      = ne00;
             const int64_t d_inner      = ne01;
-            const int64_t n_seq_tokens = ne11;
-            const int64_t n_seqs       = ne02;
+            const int64_t n_head       = ne02;
+            const int64_t n_group      = ne41;
+            const int64_t n_seq_tokens = ne12;
+            const int64_t n_seqs       = ne13;

-            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+            id<MTLComputePipelineState> pipeline = nil;
+
+            if (ne30 == 1) {
+                // Mamba-2
+                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+            } else {
+                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+            }

             ggml_metal_kargs_ssm_scan args = {
                 /*.d_state      =*/ d_state,
                 /*.d_inner      =*/ d_inner,
+                /*.n_head       =*/ n_head,
+                /*.n_group      =*/ n_group,
                 /*.n_seq_tokens =*/ n_seq_tokens,
                 /*.n_seqs       =*/ n_seqs,
-                /*.nb00         =*/ nb00,
-                /*.nb01         =*/ nb01,
-                /*.nb02         =*/ nb02,
-                /*.nb10         =*/ nb10,
-                /*.nb11         =*/ nb11,
-                /*.nb12         =*/ nb12,
-                /*.nb13         =*/ nb13,
-                /*.nb20         =*/ nb20,
-                /*.nb21         =*/ nb21,
-                /*.nb22         =*/ nb22,
-                /*.nb30         =*/ nb30,
-                /*.nb31         =*/ nb31,
-                /*.nb40         =*/ nb40,
-                /*.nb41         =*/ nb41,
-                /*.nb42         =*/ nb42,
-                /*.nb50         =*/ nb50,
-                /*.nb51         =*/ nb51,
-                /*.nb52         =*/ nb52,
+                /*.nb01         =*/ nb01,
+                /*.nb02         =*/ nb02,
+                /*.nb03         =*/ nb03,
+                /*.nb11         =*/ nb11,
+                /*.nb12         =*/ nb12,
+                /*.nb13         =*/ nb13,
+                /*.nb21         =*/ nb21,
+                /*.nb22         =*/ nb22,
+                /*.nb31         =*/ nb31,
+                /*.nb41         =*/ nb41,
+                /*.nb42         =*/ nb42,
+                /*.nb43         =*/ nb43,
+                /*.nb51         =*/ nb51,
+                /*.nb52         =*/ nb52,
+                /*.nb53         =*/ nb53,
             };

             [encoder setComputePipelineState:pipeline];
@@ -2874,10 +2917,17 @@ static bool ggml_metal_encode_node(
             [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
             [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
             [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-            [encoder setBytes:&args length:sizeof(args) atIndex:7];
+            [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
+            [encoder setBytes:&args length:sizeof(args) atIndex:8];

-            [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            if (ne30 == 1) {
+                // Mamba-2
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } else {
+                GGML_ASSERT(d_inner == 1);
+                [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            }
         } break;
     case GGML_OP_RWKV_WKV6:
         {
@@ -4979,7 +5029,11 @@ static bool ggml_metal_encode_node(
             /*.nb21 =*/ nb21,
             /*.nb22 =*/ nb22,
             /*.nb23 =*/ nb23,
+            /*.ne32 =*/ ne32,
+            /*.ne33 =*/ ne33,
             /*.nb31 =*/ nb31,
+            /*.nb32 =*/ nb32,
+            /*.nb33 =*/ nb33,
             /*.ne1  =*/ ne1,
             /*.ne2  =*/ ne2,
             /*.scale =*/ scale,
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
 }

 void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+#pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
     float max  = 0.0f;
@@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
 }

 void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+#pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
     float max  = 0.0f;
@@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
 }

 void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max

     for (int j = 0; j < QK8_0; j++) {
@@ -1258,6 +1261,50 @@ kernel void kernel_swiglu(
     }
 }

+kernel void kernel_geglu_erf(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_quick(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
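Plain C++ reference of the gating math in the two kernels above: GEGLU_ERF gates the second operand with the exact (erf-based) GELU, and GEGLU_QUICK with the sigmoid approximation driven by GELU_QUICK_COEF = -1.702. Function names here are illustrative only.

#include <cmath>

static float geglu_erf_ref(float x0, float x1) {
    const float SQRT_2_INV = 0.70710678118654752440f;
    const float gelu_erf = 0.5f * x0 * (1.0f + std::erf(x0 * SQRT_2_INV));
    return gelu_erf * x1;
}

static float geglu_quick_ref(float x0, float x1) {
    const float GELU_QUICK_COEF = -1.702f;
    const float gelu_quick = x0 * (1.0f / (1.0f + std::exp(GELU_QUICK_COEF * x0)));
    return gelu_quick * x1;
}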
@@ -1320,24 +1367,28 @@ kernel void kernel_soft_max(
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
-    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
-    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
-
-    device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
-    device const T     * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
-    device       float * pdst  = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+        uint3  tptg[[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x;
+
+    const int32_t i13 = i03%args.ne13;
+    const int32_t i12 = i02%args.ne12;
+    const int32_t i11 = i01;
+
+    device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+    device const T     * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device       float * pdst  = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);

     float slope = 1.0f;

     // ALiBi
     if (args.max_bias > 0.0f) {
-        const int64_t h = i02;
+        const int32_t h = i02;

         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1348,13 +1399,13 @@ kernel void kernel_soft_max(
     // parallel max
     float lmax = -INFINITY;

-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
     }

     // find the max value in the block
     float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = -INFINITY;
         }
@@ -1373,7 +1424,7 @@ kernel void kernel_soft_max(

     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
@@ -1385,7 +1436,7 @@ kernel void kernel_soft_max(

     float sum = simd_sum(lsum);

-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
         }
@@ -1404,7 +1455,7 @@ kernel void kernel_soft_max(

     const float inv_sum = 1.0f/sum;

-    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         pdst[i00] *= inv_sum;
     }
 }
@@ -1416,23 +1467,27 @@ kernel void kernel_soft_max_4(
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
-        uint  tgpig[[threadgroup_position_in_grid]],
-        uint  tpitg[[thread_position_in_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
-        uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
-    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
-    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
-
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
-    device const T      * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
-    device       float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+        uint3  tptg[[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x;
+
+    const int32_t i13 = i03%args.ne13;
+    const int32_t i12 = i02%args.ne12;
+    const int32_t i11 = i01;
+
+    device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+    device const T      * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device       float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);

     float slope = 1.0f;

     if (args.max_bias > 0.0f) {
-        const int64_t h = i02;
+        const int32_t h = i02;

         const float base = h < args.n_head_log2 ? args.m0 : args.m1;
         const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
@@ -1443,14 +1498,14 @@ kernel void kernel_soft_max_4(
     // parallel max
     float4 lmax4 = -INFINITY;

-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

     float max_val = simd_max(lmax);
-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = -INFINITY;
         }
@@ -1469,7 +1524,7 @@ kernel void kernel_soft_max_4(

     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
@@ -1483,7 +1538,7 @@ kernel void kernel_soft_max_4(

     float sum = simd_sum(lsum);

-    if (ntg > N_SIMDWIDTH) {
+    if (tptg.x > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
         }
@@ -1502,7 +1557,7 @@ kernel void kernel_soft_max_4(

     const float inv_sum = 1.0f/sum;

-    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         pdst4[i00] *= inv_sum;
     }
 }
@@ -1588,7 +1643,7 @@ kernel void kernel_ssm_conv_f32(
     x[0] = sumf;
 }

-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -1596,46 +1651,119 @@ kernel void kernel_ssm_scan_f32(
         device const void * src3,
         device const void * src4,
         device const void * src5,
+        device const void * src6,
         device      float * dst,
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
+    const int64_t i1 = 0;
+    const int64_t ir = tgpig.x; // current head
+    const int64_t i3 = tgpig.y; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);

     const int64_t nc  = args.d_state;
-    // const int64_t nr  = args.d_inner;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
-    // const int64_t n_s = args.n_seqs;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);

     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
-        device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
-        device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
         float sumf = 0.0f;

         for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
             sumf += state * C[i0];
             s[i] = state;
         }

         y[0] = sumf;
+
+        // recurse
+        s0 = s;
     }
 }
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+kernel void kernel_ssm_scan_f32_group(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device      float * dst,
+        constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i1 = tgpig.x;
+    const int64_t ir = tgpig.y; // current head
+    const int64_t i3 = tgpig.z; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
+
+    const int64_t nc  = args.d_state;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
+    const int64_t n_t = args.n_seq_tokens;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
+        const float dA = exp(dt_soft_plus * A[0]);
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * dA) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+
+        // recurse
+        s0 = s;
+    }
+}
@@ -3776,7 +3904,7 @@ kernel void kernel_flash_attn_ext(
         // load the mask in shared memory
         #pragma unroll(Q)
         for (short j = 0; j < Q; ++j) {
-            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+            device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

             const float m = pm[ic + tiisg];
@@ -4262,7 +4390,7 @@ kernel void kernel_flash_attn_ext_vec(
     const bool has_mask = mask != q;

     // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

     float slope = 1.0f;
@ -1,7 +1,9 @@
|
||||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||||
|
|
||||||
#define GELU_COEF_A 0.044715f
|
#define GELU_COEF_A 0.044715f
|
||||||
|
#define GELU_QUICK_COEF -1.702f
|
||||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
||||||
|
#define SQRT_2_INV 0.70710678118654752440084436210484f
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
// geglu
|
// geglu
|
||||||
|
@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
|
||||||
dst_row[i0] = silu*x1;
|
dst_row[i0] = silu*x1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// geglu_erf
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
kernel void kernel_geglu_erf(
|
||||||
|
global char * src0,
|
||||||
|
ulong offset0,
|
||||||
|
global char * src1,
|
||||||
|
ulong offset1,
|
||||||
|
global char * dst,
|
||||||
|
ulong offsetd,
|
||||||
|
ulong nb01,
|
||||||
|
ulong nb11,
|
||||||
|
int ne0,
|
||||||
|
ulong nb1,
|
||||||
|
int ne00_off,
|
||||||
|
int ne10_off
|
||||||
|
) {
|
||||||
|
src0 = (global char*)((global char*)src0 + offset0);
|
||||||
|
src1 = (global char*)((global char*)src1 + offset1);
|
||||||
|
dst = (global char*)((global char*)dst + offsetd);
|
||||||
|
|
||||||
|
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||||||
|
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||||||
|
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||||||
|
|
||||||
|
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||||
|
const float x0 = src0_row[i0];
|
||||||
|
const float x1 = src1_row[i0];
|
||||||
|
|
||||||
|
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
|
||||||
|
|
||||||
|
dst_row[i0] = gelu_erf*x1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_geglu_erf_f16(
|
||||||
|
global char * src0,
|
||||||
|
ulong offset0,
|
||||||
|
global char * src1,
|
||||||
|
ulong offset1,
|
||||||
|
global char * dst,
|
||||||
|
ulong offsetd,
|
||||||
|
ulong nb01,
|
||||||
|
ulong nb11,
|
||||||
|
int ne0,
|
||||||
|
ulong nb1,
|
||||||
|
int ne00_off,
|
||||||
|
int ne10_off
|
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global char*)((global char*)dst + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+//------------------------------------------------------------------------------
+// geglu_quick
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_quick(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global char*)((global char*)dst + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
+kernel void kernel_geglu_quick_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb01,
+        ulong nb11,
+        int ne0,
+        ulong nb1,
+        int ne00_off,
+        int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global char*)((global char*)dst + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
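The GEGLU kernels above gate one half of the row with a GELU-style activation and multiply by the other half. As a minimal host-side reference sketch of the same math (not part of the kernels; it assumes GELU_QUICK_COEF == -1.702f and SQRT_2_INV == 1/sqrt(2), matching the constants used by the other ggml backends):

// Reference sketch only; constant values are assumptions, see note above.
#include <cmath>

static float geglu_erf_ref(float x0, float x1) {
    // exact GELU gate computed via erf, multiplied by the second half of the row
    const float gelu_erf = 0.5f * x0 * (1.0f + std::erf(x0 * 0.70710678f));
    return gelu_erf * x1;
}

static float geglu_quick_ref(float x0, float x1) {
    // "quick" GELU approximation: x * sigmoid(1.702 * x)
    const float gelu_quick = x0 * (1.0f / (1.0f + std::exp(-1.702f * x0)));
    return gelu_quick * x1;
}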
@@ -240,6 +240,21 @@ enum vk_device_architecture {
     INTEL_XE2,
 };
 
+// HSK x HSV
+enum FaHeadSizes {
+    FA_HEAD_SIZE_64,
+    FA_HEAD_SIZE_80,
+    FA_HEAD_SIZE_96,
+    FA_HEAD_SIZE_112,
+    FA_HEAD_SIZE_128,
+    FA_HEAD_SIZE_192,
+    FA_HEAD_SIZE_192_128,
+    FA_HEAD_SIZE_256,
+    FA_HEAD_SIZE_576_512,
+    FA_HEAD_SIZE_UNSUPPORTED,
+    FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
+};
+
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
     vk::PhysicalDeviceProperties props = device.getProperties();
 
@@ -457,6 +472,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_geglu_erf[2];
+    vk_pipeline pipeline_geglu_quick[2];
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
@@ -483,26 +500,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
 
-    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
 
-    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
-    vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
@@ -649,6 +651,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -659,7 +662,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
@@ -674,6 +676,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
@@ -772,6 +775,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -1010,7 +1021,7 @@ struct ggml_backend_vk_context {
 
     // number of additional consecutive nodes that are being fused with the
    // node currently being processed
-    uint32_t num_additional_fused_ops {};
+    int num_additional_fused_ops {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1706,6 +1717,35 @@ enum FaCodePath {
     FA_COOPMAT2,
 };
 
+static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
+    if (hsk != 192 && hsk != 576 && hsk != hsv) {
+        return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+    switch (hsk) {
+    case 64: return FA_HEAD_SIZE_64;
+    case 80: return FA_HEAD_SIZE_80;
+    case 96: return FA_HEAD_SIZE_96;
+    case 112: return FA_HEAD_SIZE_112;
+    case 128: return FA_HEAD_SIZE_128;
+    case 192:
+        if (hsv == 192) {
+            return FA_HEAD_SIZE_192;
+        } else if (hsv == 128) {
+            return FA_HEAD_SIZE_192_128;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    case 256: return FA_HEAD_SIZE_256;
+    case 576:
+        if (hsv == 512) {
+            return FA_HEAD_SIZE_576_512;
+        } else {
+            return FA_HEAD_SIZE_UNSUPPORTED;
+        }
+    default: return FA_HEAD_SIZE_UNSUPPORTED;
+    }
+}
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
@@ -1726,8 +1766,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
     }
 }
 
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
+    GGML_UNUSED(hsv);
 
     if (path == FA_SCALAR) {
         if (small_rows) {
@@ -1751,7 +1792,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
     }
 
     // small cols to reduce register count
-    if (ggml_is_quantized(type) || D == 256) {
+    if (ggml_is_quantized(type) || hsk >= 256) {
         return {64, 32};
     }
     return {64, 64};
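For orientation, a minimal sketch of how the new head-size enum is meant to be consumed when selecting a pipeline slot; `device`, `k_type`, `f16acc`, `small_rows` and `aligned` are illustrative names assumed to be in scope, not taken from the diff:

// Sketch: (HSK, HSV) -> FaHeadSizes -> pipeline slot (assumed surrounding names)
FaHeadSizes hs = fa_get_head_sizes(192, 128);   // K=192 with V=128 -> FA_HEAD_SIZE_192_128
if (hs != FA_HEAD_SIZE_UNSUPPORTED) {
    // vk_pipeline * p = &device->pipeline_flash_attn_f32_f16_cm2[k_type][hs][f16acc][small_rows][aligned];
}
FaHeadSizes bad = fa_get_head_sizes(96, 64);    // mismatched K/V head sizes -> FA_HEAD_SIZE_UNSUPPORTED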
@@ -2044,19 +2085,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
     };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
         // For scalar, use 128 (arbitrary)
+        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
+        const uint32_t D = (hsk|hsv);
         uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
             ? scalar_flash_attention_workgroup_size
             : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2065,26 +2108,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
         GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
     };
 
-#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
-        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
 
         CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
         CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2793,6 +2839,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(geglu_erf)
+    CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
 
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -3703,7 +3751,6 @@ static void ggml_vk_instance_init() {
 
     }
 
-    size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
     vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
 
     // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -6017,24 +6064,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
     // Needs to be kept up to date on shader changes
+    GGML_UNUSED(hsv);
     const uint32_t wg_size = scalar_flash_attention_workgroup_size;
     const uint32_t Br = scalar_flash_attention_num_large_rows;
     const uint32_t Bc = scalar_flash_attention_Bc;
 
+    const uint32_t tmpsh = wg_size * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+
+    const uint32_t masksh = Bc * Br * sizeof(float);
+
+    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+
+    return supported;
+}
+
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+    // Needs to be kept up to date on shader changes
+    GGML_UNUSED(hsv);
+    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
+    const uint32_t Bc = scalar_flash_attention_Bc;
+
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
     const uint32_t tmpsh = wg_size * sizeof(float);
     const uint32_t tmpshv4 = wg_size * 4 * acctype;
 
-    const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+    const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
 
-    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const uint32_t kshstride = D / 4 + 2;
+    const uint32_t kshstride = hsk / 4 + 2;
     const uint32_t ksh = Bc * kshstride * f16vec4;
 
     const uint32_t slope = Br * sizeof(float);
@@ -6042,7 +6112,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
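As a worked instance of the scalar-path shared-memory budget computed above, with example numbers only (the real wg_size/Br/Bc constants are defined earlier in the file and may differ):

// Assumed example: wg_size = 128, Br = 8, Bc = 64, hsk = 576
//   tmpsh   = 128 * 4                =   512 bytes
//   tmpshv4 = 128 * 4 * 4            =  2048 bytes
//   masksh  =  64 * 8 * 4            =  2048 bytes
//   Qf      =   8 * (576/4 + 2) * 16 = 18688 bytes
//   total   = 23296 bytes, which fits a typical 48 KiB maxComputeSharedMemorySize.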
@@ -6064,13 +6134,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
-    const uint32_t D = neq0;
+    const uint32_t HSK = nek0;
+    const uint32_t HSV = nev0;
     uint32_t N = neq1;
     const uint32_t KV = nek1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == HSV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -6078,12 +6149,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == HSK);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     GGML_ASSERT(nev1 == nek1);
 
@@ -6104,7 +6172,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
                                          (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
 
-    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+    const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
 
     if (!coopmat_shape_supported || !coopmat_shmem_supported) {
         path = FA_SCALAR;
@@ -6157,47 +6225,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         path = FA_SCALAR;
     }
 
+    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
+    if (path == FA_SCALAR &&
+        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
+        small_rows = true;
+    }
+
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
+    FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
+
     switch (path) {
     case FA_SCALAR:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT1:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
         break;
     case FA_COOPMAT2:
-        switch (D) {
-        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
-        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
-        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
-        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
-        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
-        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
-        default:
-            GGML_ASSERT(!"unsupported D value");
-            return;
-        }
+        pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
         break;
     default:
         GGML_ASSERT(0);
@@ -6227,7 +6273,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
@@ -6237,9 +6283,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -6331,11 +6377,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         (uint32_t)neq2, (uint32_t)neq3,
         (uint32_t)nek2, (uint32_t)nek3,
         (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
         q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
         k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
         scale, max_bias, logit_softcap,
         mask != nullptr, n_head_log2, m0, m1,
         gqa_ratio, split_kv, split_k };
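A worked instance of the split_k temporary sizing above, using illustrative numbers only (HSV, ne1, ne3 and split_k come from the tensor being processed at runtime):

// Assumed example: HSV = 128, ne1 = 1, ne3 = 2, split_k = 8
//   per split x batch: O matrix = HSV * ne1 * 4 bytes = 512, plus m and L rows = ne1 * 4 * 2 = 8
//   split_k_size = (512 + 8) * split_k * ne3 = 520 * 16 = 8320 bytes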
@@ -6358,13 +6403,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2, { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
@@ -6558,6 +6603,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
         case GGML_GLU_OP_SWIGLU:
             return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+        case GGML_GLU_OP_GEGLU_ERF:
+            return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
+        case GGML_GLU_OP_GEGLU_QUICK:
+            return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
         default:
             break;
         }
@@ -7690,7 +7739,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7699,6 +7754,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -8893,6 +8951,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             break;
         default:
             return false;
@@ -9140,6 +9200,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
             break;
         default:
@@ -9358,6 +9420,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
        case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
             break;
         default:
@@ -10168,6 +10232,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             return ggml_is_contiguous(op->src[0]) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -10248,19 +10314,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
             auto device = ggml_vk_get_device(ctx->device);
             bool coopmat2 = device->coopmat2;
-            switch (op->src[0]->ne[0]) {
-            case 64:
-            case 80:
-            case 96:
-            case 112:
-            case 128:
-            case 256:
-                break;
-            default:
-                return false;
-            }
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                // different head sizes of K and V are not supported yet
+            FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
+            if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                 return false;
             }
             if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10272,6 +10327,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
+            // TODO: support broadcast
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+            if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+                return false;
+            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10340,6 +10401,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+                return false;
+            } break;
         case GGML_OP_CONT:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -10430,6 +10497,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
+            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
@@ -11,7 +11,8 @@
 #include "types.comp"
 #include "flash_attn_base.comp"
 
-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;
 
 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-    uint32_t offset = (iq2 + r) * D + c;
+    uint32_t offset = (iq2 + r) * HSV + c;
     data_o[o_offset + offset] = D_TYPE(elem);
     return elem;
 }
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
 
 shared float masksh[Bc][Br];
-shared vec4 Qf[Br][D / 4];
+shared vec4 Qf[Br][HSK / 4];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
 
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
-    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-        uint32_t d = (idx + tid) % (D / 4);
-        uint32_t r = (idx + tid) / (D / 4);
-        if (r < Br && d < D / 4 &&
+    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (HSK / 4);
+        uint32_t r = (idx + tid) / (HSK / 4);
+        if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
             Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
         }
     }
     barrier();
 
-    vec4 Of[Br][D_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    vec4 Of[Br][HSV_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] = vec4(0.0);
         }
@@ -99,6 +100,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -112,7 +117,7 @@ void main() {
 
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;
@@ -150,7 +155,7 @@ void main() {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br) {
-                masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
             }
         }
         barrier();
@@ -191,14 +196,14 @@ void main() {
             Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                 Of[r][d] = eMf[r] * Of[r][d];
             }
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
                 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                 uint ib = coord / BLOCK_SIZE;
@@ -255,7 +260,7 @@ void main() {
             Lf[r] = tmpsh[d_tid];
             barrier();
 
-            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 
                 Of[r][d] = eMf * Of[r][d];
                 tmpshv4[tid] = Of[r][d];
@@ -277,11 +282,11 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
@@ -289,7 +294,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
                 perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +310,18 @@ void main() {
         Lfrcp[r] = 1.0 / Lf[r];
     }
 
-    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
        }
    }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                         perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
@@ -326,9 +331,9 @@ void main() {
     } else {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (i * Br + r < N) {
-                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                     }
                 }
             }
@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||||
layout (constant_id = 1) const uint32_t Br = 1;
|
layout (constant_id = 1) const uint32_t Br = 1;
|
||||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||||
layout (constant_id = 3) const uint32_t D = 32;
|
layout (constant_id = 3) const uint32_t HSK = 32;
|
||||||
layout (constant_id = 4) const uint32_t Clamp = 0;
|
layout (constant_id = 4) const uint32_t HSV = 32;
|
||||||
layout (constant_id = 5) const uint32_t D_split = 16;
|
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||||
|
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint32_t N;
|
uint32_t N;
|
||||||
|
@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
|
||||||
uint32_t nev2;
|
uint32_t nev2;
|
||||||
uint32_t nev3;
|
uint32_t nev3;
|
||||||
uint32_t nem1;
|
uint32_t nem1;
|
||||||
|
uint32_t nem2;
|
||||||
|
|
||||||
uint32_t nb01;
|
uint32_t nb01;
|
||||||
uint32_t nb02;
|
uint32_t nb02;
|
||||||
|
@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
|
||||||
uint32_t nb21;
|
uint32_t nb21;
|
||||||
uint32_t nb22;
|
uint32_t nb22;
|
||||||
uint32_t nb23;
|
uint32_t nb23;
|
||||||
uint32_t nb31;
|
|
||||||
|
|
||||||
float scale;
|
float scale;
|
||||||
float max_bias;
|
float max_bias;
|
||||||
|
|
|
@@ -13,7 +13,9 @@
 #include "types.comp"
 #include "flash_attn_base.comp"

-const uint32_t D_per_thread = D / D_split;
+const uint32_t HSK_per_thread = HSK / D_split;
+const uint32_t HSV_per_thread = HSV / D_split;

 const uint32_t row_split = 4;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-uint32_t offset = (iq2 + r) * D + c;
+uint32_t offset = (iq2 + r) * HSV + c;
 data_o[o_offset + offset] = D_TYPE(elem);
 return elem;
 }
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];

-const uint32_t qstride = D / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];

-// Avoid padding for D==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+// Avoid padding for hsk==256 to make it fit in 48KB shmem.
+const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];

-const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];

 shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {

 uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

-[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
-uint32_t d = (idx + tid) % (D / 4);
-uint32_t r = (idx + tid) / (D / 4);
-if (r < Br && d < D / 4 &&
+[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
+uint32_t d = (idx + tid) % (HSK / 4);
+uint32_t r = (idx + tid) / (HSK / 4);
+if (r < Br && d < HSK / 4 &&
 i * Br + r < N) {
 Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
 }
 }
 barrier();

-ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 Of[r][d] = ACC_TYPEV4(0.0);
 }
@@ -123,14 +125,18 @@ void main() {
 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
 uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+uint32_t m_offset = 0;
+if (p.nem2 != 1) {
+m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+}

 [[dont_unroll]]
 for (uint32_t j = start_j; j < end_j; ++j) {

-[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
-uint32_t d = (idx + tid) % (D / 4);
-uint32_t c = (idx + tid) / (D / 4);
-if (c < Bc && d < D / 4) {
+[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+uint32_t d = (idx + tid) % (HSK / 4);
+uint32_t c = (idx + tid) / (HSK / 4);
+if (c < Bc && d < HSK / 4) {
 #if BLOCK_SIZE > 1
 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
 uint ib = coord / BLOCK_SIZE;
@@ -145,14 +151,14 @@ void main() {
 }
 barrier();

-// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
-// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
+// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
 // This is written transposed in order to allow for N being 8 if implementations need it
 coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
 coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
 coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

-for (uint32_t d = 0; d < D / 16; ++d) {
+for (uint32_t d = 0; d < HSK / 16; ++d) {
 coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

 uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -181,7 +187,7 @@ void main() {
 uint32_t c = (idx + tid) % Bc;
 uint32_t r = (idx + tid) / Bc;
 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
 }
 }
 barrier();
@@ -202,7 +208,7 @@ void main() {
 eMf[r] = exp(Moldf - Mf[r]);
 }

-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 Of[r][d] = float16_t(eMf[r]) * Of[r][d];
 }
@@ -217,7 +223,7 @@ void main() {
 Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
 Lf[r] += Pf[r];
 }
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 #if BLOCK_SIZE > 1
 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
 uint ib = coord / BLOCK_SIZE;
@@ -280,7 +286,7 @@ void main() {
 }

 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

 Of[r][d] = float16_t(eMf[r]) * Of[r][d];
 tmpshv4[tid] = Of[r][d];
@@ -300,11 +306,11 @@ void main() {
 // If there is split_k, then the split_k resolve shader does the final
 // division by L. Store the intermediate O value and per-row m and L values.
 if (p.k_num > 1) {
-uint32_t o_offset = D * p.ne1 * split_k_index;
+uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);

 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 if (tile_row(r) < N) {
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
 perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
 }
@@ -312,7 +318,7 @@ void main() {
 }
 }

-o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 if (tile_row(r) < N) {
 perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -328,18 +334,18 @@ void main() {
 Lfrcp[r] = 1.0 / Lf[r];
 }

-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 Of[r][d] *= float16_t(Lfrcp[r]);
 }
 }

-uint32_t o_offset = iq3*p.ne2*p.ne1;
+uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

 if (p.gqa_ratio > 1) {
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 if (tile_row(r) < N) {
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
 perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
 }
@@ -349,9 +355,9 @@ void main() {
 } else {
 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
 if (i * Br + tile_row(r) < N) {
-[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
 [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
 }
 }
 }

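Note: across these flash-attention shaders the single head-size constant D is split into HSK (Q/K head size) and HSV (V head size), so the score tile is shaped by HSK while the accumulator and output are shaped by HSV. A minimal NumPy sketch of that shape relationship, with hypothetical tile and head sizes that are not taken from the repo:

import numpy as np
Br, Bc, HSK, HSV = 16, 32, 192, 128           # assumed sizes for illustration only
Q = np.random.rand(Br, HSK)
K = np.random.rand(Bc, HSK)
V = np.random.rand(Bc, HSV)
S = Q @ K.T                                    # (Br, Bc): depends only on HSK
P = np.exp(S - S.max(axis=1, keepdims=True))   # row-wise softmax
P /= P.sum(axis=1, keepdims=True)
O = P @ V                                      # (Br, HSV): output width is HSV
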
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 // Rows index by Q's dimension 2, and the first N rows are valid.
 D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
 {
-if (r < N && c < D) {
-uint32_t offset = (iq2 + r) * D + c;
+if (r < N && c < HSV) {
+uint32_t offset = (iq2 + r) * HSV + c;
 data_o[o_offset + offset] = D_TYPE(elem);
 }
 return elem;
@@ -86,9 +86,9 @@ void main() {
 tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif

-tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
-tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
-tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
+tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
+tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
+tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);

 // hint to the compiler that strides are aligned for the aligned variant of the shader
 if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
 tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
 tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

-coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
-coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
+coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
+coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;

 uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
+coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));

-Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
+Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
 Qf16 *= float16_t(p.scale);

-coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);

 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

@@ -130,15 +130,20 @@ void main() {
 coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
 }

+uint32_t m_offset = 0;
+if (p.nem2 != 1) {
+m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+}
+
 [[dont_unroll]]
 for (uint32_t j = start_j; j < end_j; ++j) {

 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

-coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
+coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;

 uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
+coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
 S = coopMatMulAdd(Qf16, K_T, S);

 if (p.logit_softcap != 0.0f) {
@@ -155,7 +160,7 @@ void main() {

 coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

-coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

 S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
 }
@@ -203,42 +208,42 @@ void main() {
 rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
 rowsum = coopMatMulAdd(P_A, One, rowsum);

-coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
+coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
 uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
+coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);

 L = eM*L + rowsum;

 // This is the "diagonal" matrix in the paper, but since we do componentwise
 // multiply rather than matrix multiply it has the diagonal element smeared
 // across the row
-coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
+coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;

 // resize eM by using smear/reduce
 coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

 // multiply with fp16 accumulation, then add to O.
-coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
+coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
 PV = coopMatMulAdd(P_A, V, PV);

-O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
+O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
 }

 // If there is split_k, then the split_k resolve shader does the final
 // division by L. Store the intermediate O value and per-row m and L values.
 if (p.k_num > 1) {
-coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);

-uint32_t o_offset = D * p.ne1 * split_k_index;
+uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
 coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

-o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
 coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
 coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
 return;
 }

-coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
+coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;

 // resize L by using smear/reduce
 coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -250,18 +255,18 @@ void main() {

 O = Ldiag*O;

-uint32_t o_offset = iq3*p.ne2*p.ne1;
+uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

-coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
+coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
 if (p.gqa_ratio > 1) {
 coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 } else {
 tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);

 // permute dimensions
 tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

-coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
 }
 }

@@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
 layout (push_constant) uniform parameter {
 uint D;
 uint N;
+uint ne3;
 uint k_num;
 } p;

@@ -19,13 +20,14 @@ void main() {
 // Each workgroup handles a row
 const uint n = gl_WorkGroupID.x;
 const uint tid = gl_LocalInvocationID.x;
+const uint iq3 = gl_WorkGroupID.z;

 uint D = p.D;
 uint N = p.N;
 uint k_num = p.k_num;

-uint l_offset = D * N * k_num + n;
-uint m_offset = D * N * k_num + N + n;
+uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
+uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
 uint lm_stride = N * 2;

 // Compute the max m value for the row
@@ -49,11 +51,11 @@ void main() {
 for (uint d = tid; d < D; d += BLOCK_SIZE) {
 float O = 0.0;
 [[unroll]] for (uint k = 0; k < k_num; ++k) {
-uint o_offset = D * N * k + D * n + d;
+uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
 float m = data_a[m_offset + k * lm_stride];
 O += exp(m - m_max) * data_a[o_offset];
 }
 O *= L;
-data_d[D * n + d] = O;
+data_d[iq3 * D * N + D * n + d] = O;
 }
 }

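Note: the split-k scratch buffer now also spans the fourth dimension (ne3, indexed by iq3). A rough Python sketch of the offsets this resolve shader reads, assuming the layout is the partial O blocks followed by interleaved per-split L/M rows (the helper name is hypothetical, not from the repo):

def split_k_offsets(D, N, ne3, k_num, iq3, k, n, d):
    # partial O value written by split k for row n, element d
    o_offset = D * N * (k + iq3 * k_num) + D * n + d
    # per-row L and M values stored after all O blocks; lm_stride = N * 2 per split
    base = D * N * ne3 * k_num + N * iq3 * k_num * 2
    l_offset = base + k * (N * 2) + n
    m_offset = base + k * (N * 2) + N + n
    return o_offset, l_offset, m_offset
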
ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp (new file, 27 lines)
@@ -0,0 +1,27 @@
+#version 450
+
+#include "glu_head.comp"
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+const float p_erf = 0.3275911f;
+const float a1_erf = 0.254829592f;
+const float a2_erf = -0.284496736f;
+const float a3_erf = 1.421413741f;
+const float a4_erf = -1.453152027f;
+const float a5_erf = 1.061405429f;
+
+const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+float op(float a, float b) {
+const float a_div_sqr2 = a * SQRT_2_INV;
+const float sign_x = sign(a_div_sqr2);
+const float x = abs(a_div_sqr2);
+const float t = 1.0f / (1.0f + p_erf * x);
+const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+const float erf_approx = sign_x * y;
+
+return 0.5f * a * (1.0f + erf_approx) * b;
+}
+
+#include "glu_main.comp"

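Note: op() above computes GEGLU with a tanh-free erf approximation. A small plain-Python check of the same Abramowitz and Stegun 7.1.26 constants against math.erf (a sketch for reference only, not part of the repo):

import math

def erf_approx(x):
    p, a1, a2, a3, a4, a5 = 0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429
    s = math.copysign(1.0, x)
    x = abs(x)
    t = 1.0 / (1.0 + p * x)
    y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * math.exp(-x * x)
    return s * y

def geglu_erf(a, b):
    # same form as op(): 0.5 * a * (1 + erf(a / sqrt(2))) * b
    return 0.5 * a * (1.0 + erf_approx(a / math.sqrt(2.0))) * b

# 7.1.26 has a maximum absolute error of about 1.5e-7
assert all(abs(erf_approx(x) - math.erf(x)) < 1e-6 for x in (-2.0, -0.5, 0.0, 0.7, 3.0))
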
ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp (new file, 11 lines)
@@ -0,0 +1,11 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_QUICK_COEF = -1.702f;
+
+float op(float a, float b) {
+return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
+}
+
+#include "glu_main.comp"

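Note: the quick variant approximates GELU with a logistic sigmoid, gelu_quick(x) ~ x * sigmoid(1.702 * x), and the GLU form then multiplies by the gate. An equivalent plain-Python form of op() for reference (a sketch, not taken from the repo):

import math

def geglu_quick(a, b):
    # 1 / (1 + exp(-1.702 * a)) is sigmoid(1.702 * a)
    return a * (1.0 / (1.0 + math.exp(-1.702 * a))) * b
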
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
 {
 uint KX;
 uint KY;
+uint ne00;
+uint ne01;
+uint ne02;
+uint ne12;
+uint ne13;
+uint nb11;
+uint nb12;
+uint nb13;
 float scale;
 float max_bias;
 float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
 void soft_max(uint num_iters) {
 const uint tid = gl_LocalInvocationID.x;
 const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+const uint32_t i01 = rowx % p.ne01;
+
+uint rowy_start = 0;
+if (p.KY > 0) {
+rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+}

 if (rowx >= p.nrows_x) {
 return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {

 // ALiBi
 if (p.max_bias > 0.0f) {
-const uint h = rowx/p.KY; // head index
+const uint h = (rowx / p.ne01) % p.ne02; // head index

 const float base = h < p.n_head_log2 ? p.m0 : p.m1;
 const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {

 FLOAT_TYPE b = FLOAT_TYPE(0);
 if (p.KY > 0 && col < p.KX) {
-b = data_b[rowy * p.KX + col];
+b = data_b[rowy_start + col];
 }

 FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
 if (idx < DATA_CACHE_SIZE) {
 val = exp(data_cache[idx] - max_val);
 } else {
-val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
 }
 sum += val;
 if (idx < DATA_CACHE_SIZE) {

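Note: with the soft-max mask no longer restricted to a matrix, the shader decomposes the flat row index into (i01, i02, i03) and walks the mask by its own strides, so a mask with ne12 or ne13 equal to 1 broadcasts across heads and batches. A Python sketch of the same index arithmetic (hypothetical helper, mirroring the shader):

def mask_row_start(rowx, ne01, ne02, ne12, ne13, nb11, nb12, nb13):
    i03 = rowx // (ne01 * ne02)
    i02 = (rowx - i03 * ne01 * ne02) // ne01
    i01 = rowx % ne01
    # the modulo implements broadcasting when the mask has fewer heads/batches
    return i01 * nb11 + (i02 % ne12) * nb12 + (i03 % ne13) * nb13
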
@@ -607,6 +607,10 @@ void process_shaders() {
 string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
 string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

 string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

ggml/src/ggml.c (134 changed lines)
@@ -21,6 +21,9 @@
 #include <alloca.h>
 #endif

+#define GGML_VERSION "0.0.1"
+#define GGML_COMMIT "KCPP"
+
 #include <assert.h>
 #include <errno.h>
 #include <time.h>
@@ -474,6 +477,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
 return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
 }

+const char * ggml_version(void) {
+return GGML_VERSION;
+}
+
+const char * ggml_commit(void) {
+return GGML_COMMIT;
+}
+
 //
 // timing
 //
@@ -1145,9 +1156,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
 "REGLU",
 "GEGLU",
 "SWIGLU",
+"GEGLU_ERF",
+"GEGLU_QUICK",
 };

-static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");


 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2773,6 +2786,48 @@ struct ggml_tensor * ggml_swiglu_split(
 return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
 }

+// ggml_geglu_erf
+
+struct ggml_tensor * ggml_geglu_erf(
+struct ggml_context * ctx,
+struct ggml_tensor * a) {
+return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct ggml_tensor * ggml_geglu_erf_swapped(
+struct ggml_context * ctx,
+struct ggml_tensor * a) {
+return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct ggml_tensor * ggml_geglu_erf_split(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+struct ggml_tensor * b) {
+return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// ggml_geglu_quick
+
+struct ggml_tensor * ggml_geglu_quick(
+struct ggml_context * ctx,
+struct ggml_tensor * a) {
+return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct ggml_tensor * ggml_geglu_quick_swapped(
+struct ggml_context * ctx,
+struct ggml_tensor * a) {
+return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct ggml_tensor * ggml_geglu_quick_split(
+struct ggml_context * ctx,
+struct ggml_tensor * a,
+struct ggml_tensor * b) {
+return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
 // ggml_norm

 static struct ggml_tensor * ggml_norm_impl(
@@ -3679,9 +3734,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
 if (mask) {
 GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
 GGML_ASSERT(ggml_is_contiguous(mask));
-GGML_ASSERT(ggml_is_matrix(mask));
 GGML_ASSERT(mask->ne[0] == a->ne[0]);
 GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
+GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
 }

 if (max_bias > 0.0f) {
@@ -4702,13 +4758,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
 GGML_ASSERT(ggml_can_mul_mat(k, q));
 // TODO: check if vT can be multiplied by (k*qT)

+GGML_ASSERT(q->ne[3] == k->ne[3]);
+GGML_ASSERT(q->ne[3] == v->ne[3]);
+
 if (mask) {
 GGML_ASSERT(ggml_is_contiguous(mask));
-GGML_ASSERT(mask->ne[2] == 1);
-GGML_ASSERT(mask->ne[3] == 1);
 GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
 //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+
+GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
 }

 if (max_bias > 0.0f) {
@@ -4836,7 +4896,6 @@ struct ggml_tensor * ggml_ssm_conv(
 const int64_t n_s = sx->ne[2];

 // TODO: maybe support other strides than 1?
-// FIXME: this is always true?
 GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
 GGML_ASSERT(sx->ne[1] == d_inner);
 GGML_ASSERT(n_t >= 0);
@@ -4859,36 +4918,49 @@ struct ggml_tensor * ggml_ssm_scan(
 struct ggml_tensor * dt,
 struct ggml_tensor * A,
 struct ggml_tensor * B,
-struct ggml_tensor * C) {
+struct ggml_tensor * C,
+struct ggml_tensor * ids) {
 GGML_ASSERT(ggml_is_contiguous(s));
-GGML_ASSERT(ggml_is_contiguous(x));
 GGML_ASSERT(ggml_is_contiguous(dt));
 GGML_ASSERT(ggml_is_contiguous(A));
-GGML_ASSERT(ggml_is_matrix(A));
-GGML_ASSERT(ggml_is_3d(B));
-GGML_ASSERT(ggml_is_3d(s));
+GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
 GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
 GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
-GGML_ASSERT(ggml_are_same_shape(x, dt));
+GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
 GGML_ASSERT(ggml_are_same_shape(B, C));
+GGML_ASSERT(ids->type == GGML_TYPE_I32);

 {
 const int64_t d_state = s->ne[0];
-const int64_t d_inner = s->ne[1];
-const int64_t n_seq_tokens = x->ne[1];
-const int64_t n_seqs = x->ne[2];
+const int64_t head_dim = x->ne[0];
+const int64_t n_head = x->ne[1];
+const int64_t n_seq_tokens = x->ne[2];
+const int64_t n_seqs = x->ne[3];

-GGML_ASSERT(s->ne[2] == n_seqs);
-GGML_ASSERT(x->ne[0] == d_inner);
-GGML_ASSERT(A->ne[0] == d_state);
-GGML_ASSERT(A->ne[1] == d_inner);
+GGML_ASSERT(dt->ne[0] == n_head);
+GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+GGML_ASSERT(dt->ne[2] == n_seqs);
+GGML_ASSERT(ggml_is_3d(dt));
+GGML_ASSERT(s->ne[1] == head_dim);
+GGML_ASSERT(s->ne[2] == n_head);
 GGML_ASSERT(B->ne[0] == d_state);
-GGML_ASSERT(B->ne[1] == n_seq_tokens);
-GGML_ASSERT(B->ne[2] == n_seqs);
+GGML_ASSERT(B->ne[2] == n_seq_tokens);
+GGML_ASSERT(B->ne[3] == n_seqs);
+GGML_ASSERT(ids->ne[0] == n_seqs);
+GGML_ASSERT(ggml_is_vector(ids));
+GGML_ASSERT(A->ne[1] == n_head);
+GGML_ASSERT(ggml_is_matrix(A));
+
+if (A->ne[0] != 1) {
+// Mamba-1 has more granular decay factors
+GGML_ASSERT(A->ne[0] == d_state);
+}
 }

 // concatenated y + ssm_states
-struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);

 result->op = GGML_OP_SSM_SCAN;
 result->src[0] = s;
@@ -4897,6 +4969,7 @@ struct ggml_tensor * ggml_ssm_scan(
 result->src[3] = A;
 result->src[4] = B;
 result->src[5] = C;
+result->src[6] = ids;

 return result;
 }
@@ -6037,13 +6110,28 @@ static void ggml_compute_backward(
 }
 GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
 } break;
+case GGML_OP_GLU: {
+switch (ggml_get_glu_op(tensor)) {
+case GGML_GLU_OP_SWIGLU: {
+if (src0_needs_grads) {
+GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
+ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
+}
+if (src1_needs_grads) {
+ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
+}
+} break;
+default: {
+GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
+} //break;
+}
+} break;
 case GGML_OP_NONE: {
 // noop
 } break;
 case GGML_OP_COUNT:
 default: {
-fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
-GGML_ABORT("fatal error");
+GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
 } //break;
 }

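Note: the GGML_OP_GLU backward case added above follows from the product rule for y = silu(a) * b, which is what the ggml_silu_back and ggml_mul calls compute; with sigma the logistic function, a sketch of the gradients is dy/da = b * (sigma(a) + a * sigma(a) * (1 - sigma(a))) and dy/db = silu(a) = a * sigma(a).
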
@@ -170,6 +170,7 @@ class Keys:
 INNER_SIZE = "{arch}.ssm.inner_size"
 STATE_SIZE = "{arch}.ssm.state_size"
 TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+GROUP_COUNT = "{arch}.ssm.group_count"
 DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"

 class WKV:
@@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum):
 RWKV7 = auto()
 ARWKV7 = auto()
 MAMBA = auto()
+MAMBA2 = auto()
 XVERSE = auto()
 COMMAND_R = auto()
 COHERE2 = auto()
@@ -429,6 +431,7 @@ class MODEL_TENSOR(IntEnum):
 SSM_DT = auto()
 SSM_A = auto()
 SSM_D = auto()
+SSM_NORM = auto()
 SSM_OUT = auto()
 TIME_MIX_W0 = auto()
 TIME_MIX_W1 = auto()
@@ -628,6 +631,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
 MODEL_ARCH.RWKV7: "rwkv7",
 MODEL_ARCH.ARWKV7: "arwkv7",
 MODEL_ARCH.MAMBA: "mamba",
+MODEL_ARCH.MAMBA2: "mamba2",
 MODEL_ARCH.XVERSE: "xverse",
 MODEL_ARCH.COMMAND_R: "command-r",
 MODEL_ARCH.COHERE2: "cohere2",
@@ -730,6 +734,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
 MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
 MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
 MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
 MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
 MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
 MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@@ -1714,6 +1719,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
 MODEL_TENSOR.SSM_D,
 MODEL_TENSOR.SSM_OUT,
 ],
+MODEL_ARCH.MAMBA2: [
+MODEL_TENSOR.TOKEN_EMBD,
+MODEL_TENSOR.OUTPUT_NORM,
+MODEL_TENSOR.OUTPUT,
+MODEL_TENSOR.ATTN_NORM,
+MODEL_TENSOR.SSM_IN,
+MODEL_TENSOR.SSM_CONV1D,
+MODEL_TENSOR.SSM_DT,
+MODEL_TENSOR.SSM_A,
+MODEL_TENSOR.SSM_D,
+MODEL_TENSOR.SSM_NORM,
+MODEL_TENSOR.SSM_OUT,
+],
 MODEL_ARCH.XVERSE: [
 MODEL_TENSOR.TOKEN_EMBD,
 MODEL_TENSOR.OUTPUT_NORM,
@@ -2497,6 +2515,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
 KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
 KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
 KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
 KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS

 # tokenization

@@ -714,8 +714,8 @@ class GGUFWriter:
 def add_clamp_kqv(self, value: float) -> None:
 self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)

-def add_shared_kv_layers(self, value: float) -> None:
-self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
+def add_shared_kv_layers(self, value: int) -> None:
+self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)

 def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
 self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
@@ -861,6 +861,9 @@ class GGUFWriter:
 def add_ssm_time_step_rank(self, value: int) -> None:
 self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

+def add_ssm_group_count(self, value: int) -> None:
+self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
+
 def add_ssm_dt_b_c_rms(self, value: bool) -> None:
 self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

@@ -477,7 +477,7 @@ class TensorNameMap:
 "encoder.layers.{bid}.norm2", # nomic-bert
 "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
 "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
-"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
+"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
 ),

 MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
@@ -574,6 +574,10 @@ class TensorNameMap:
 "backbone.layers.{bid}.mixer.D",
 ),

+MODEL_TENSOR.SSM_NORM: (
+"backbone.layers.{bid}.mixer.norm", # mamba2
+),
+
 MODEL_TENSOR.SSM_OUT: (
 "model.layers.{bid}.out_proj",
 "backbone.layers.{bid}.mixer.out_proj",

@@ -245,9 +245,18 @@ class SpecialVocab:
 if not tokenizer_config:
 return True
 chat_template_alt = None
-chat_template_file = path / 'chat_template.json'
-if chat_template_file.is_file():
-with open(chat_template_file, encoding = 'utf-8') as f:
+chat_template_json = path / 'chat_template.json'
+chat_template_jinja = path / 'chat_template.jinja'
+if chat_template_jinja.is_file():
+with open(chat_template_jinja, encoding = 'utf-8') as f:
+chat_template_alt = f.read()
+if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
+chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
+for template_path in additional_templates:
+with open(template_path, encoding = 'utf-8') as fp:
+chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
+elif chat_template_json.is_file():
+with open(chat_template_json, encoding = 'utf-8') as f:
 chat_template_alt = json.load(f).get('chat_template')
 chat_template = tokenizer_config.get('chat_template', chat_template_alt)
 if chat_template is None or isinstance(chat_template, (str, list)):

25
klite.embd
25
klite.embd
|
@ -3208,6 +3208,7 @@ Current version indicated by LITEVER below.
|
||||||
var schedule_multiplayer_major_change = false;
|
var schedule_multiplayer_major_change = false;
|
||||||
var last_request_str = "No Requests Available"; //full context of last submitted request
|
var last_request_str = "No Requests Available"; //full context of last submitted request
|
||||||
var last_response_obj = null;
|
var last_response_obj = null;
|
||||||
|
var last_response_streamlog = "";
|
||||||
var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp
|
var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp
|
||||||
var kai_poll_recoverykey = ""; //for recovering a lost polled streaming in case of disconnect.
|
 var kai_poll_recoverykey = ""; //for recovering a lost polled streaming in case of disconnect.
 var globalabortcontroller = null;

@@ -5634,6 +5635,10 @@ Current version indicated by LITEVER below.
 },
 transform(chunk, ctrl) {
 ctrl.buf += chunk;
+if(chunk)
+{
+last_response_streamlog += chunk;
+}
 let evs = [];
 let m;
 while ((m = /^data: ?(.*)(\r?\n){2}/m.exec(ctrl.buf)) !== null) {

@@ -9343,6 +9348,10 @@ Current version indicated by LITEVER below.
 {
 lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj);
 }
+if(last_response_streamlog)
+{
+lr += "\n\nResponse:\n" + last_response_streamlog;
+}
 msgbox(lr,"Last Request Info",false);
 }
 function show_last_logprobs()

@@ -10401,7 +10410,7 @@ Current version indicated by LITEVER below.
 desired_oai_ep = transform_oai_ep(desired_oai_ep);

 let oaiheaders = {};
-if(desired_oai_key!=""){
+if(desired_oai_key!="" && !desired_oai_ep.toLowerCase().includes("pollinations.ai")){
 oaiheaders["Authorization"] = "Bearer " + desired_oai_key;
 };
 if (desired_oai_ep.toLowerCase().includes("api.mistral.ai")) {

@@ -14041,6 +14050,7 @@ Current version indicated by LITEVER below.
 redo_arr = [];
 last_request_str = "No Requests Available";
 last_response_obj = null;
+last_response_streamlog = "";
 retry_prev_text = [];
 retry_preserve_last = false;
 retry_in_progress = false;

@@ -16471,6 +16481,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(submit_payload);
 last_response_obj = null;
+last_response_streamlog = "";
 if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) {
 let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint);
 kobold_api_stream_sse(sub_endpt, submit_payload);

@@ -16643,11 +16654,15 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(oai_payload);
 last_response_obj = null;
+last_response_streamlog = "";
 let oaiheaders = {
-'Content-Type': 'application/json',
-'Authorization': 'Bearer ' + custom_oai_key
+'Content-Type': 'application/json'
 };

+if (!targetep.toLowerCase().includes("pollinations.ai")) {
+oaiheaders['Authorization'] = 'Bearer ' + custom_oai_key;
+}
+
 if(targetep.toLowerCase().includes("openrouter.ai"))
 {
 oaiheaders["HTTP-Referer"] = "https://lite.koboldai.net";

@@ -16780,6 +16795,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(claude_payload);
 last_response_obj = null;
+last_response_streamlog = "";

 let claudeheaders = {
 'Content-Type': 'application/json',

@@ -16970,6 +16986,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(payload);
 last_response_obj = null;
+last_response_streamlog = "";

 let geminiheaders = { 'Content-Type': 'application/json' };
 if(is_browser_supports_sse() && localsettings.tokenstreammode!=0)

@@ -17018,6 +17035,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(cohere_payload);
 last_response_obj = null;
+last_response_streamlog = "";
 let cohere_headers = {
 'Content-Type': 'application/json',
 'Authorization': 'Bearer ' + custom_cohere_key

@@ -17093,6 +17111,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(submit_payload);
 last_response_obj = null;
+last_response_streamlog = "";

 fetch(horde_submit_endpoint, {
 method: 'POST', // or 'PUT'
@@ -62,7 +62,7 @@ dry_seq_break_max = 128
 extra_images_max = 4

 # global vars
-KcppVersion = "1.95.1"
+KcppVersion = "1.96"
 showdebug = True
 kcpp_instance = None #global running instance
 global_memory = {"tunnel_url": "", "restart_target":"", "input_to_exit":False, "load_complete":False, "restart_override_config_target":""}
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 { LLM_ARCH_GEMMA3N, "gemma3n" },
 { LLM_ARCH_STARCODER2, "starcoder2" },
 { LLM_ARCH_MAMBA, "mamba" },
+{ LLM_ARCH_MAMBA2, "mamba2" },
 { LLM_ARCH_XVERSE, "xverse" },
 { LLM_ARCH_COMMAND_R, "command-r" },
 { LLM_ARCH_COHERE2, "cohere2" },

@@ -170,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
 { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
 { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
 { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },

 { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },

@@ -1004,6 +1006,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
 { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
 },
 },
+{
+LLM_ARCH_MAMBA2,
+{
+{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+{ LLM_TENSOR_OUTPUT, "output" },
+{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+},
+},
 {
 LLM_ARCH_XVERSE,
 {

@@ -1761,6 +1779,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
 {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
 {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
 {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

@@ -1894,6 +1913,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
 bool llm_arch_is_recurrent(const llm_arch & arch) {
 switch (arch) {
 case LLM_ARCH_MAMBA:
+case LLM_ARCH_MAMBA2:
 case LLM_ARCH_RWKV6:
 case LLM_ARCH_RWKV6QWEN2:
 case LLM_ARCH_RWKV7:

@@ -49,6 +49,7 @@ enum llm_arch {
 LLM_ARCH_GEMMA3N,
 LLM_ARCH_STARCODER2,
 LLM_ARCH_MAMBA,
+LLM_ARCH_MAMBA2,
 LLM_ARCH_XVERSE,
 LLM_ARCH_COMMAND_R,
 LLM_ARCH_COHERE2,

@@ -174,6 +175,7 @@ enum llm_kv {
 LLM_KV_SSM_CONV_KERNEL,
 LLM_KV_SSM_STATE_SIZE,
 LLM_KV_SSM_TIME_STEP_RANK,
+LLM_KV_SSM_GROUP_COUNT,
 LLM_KV_SSM_DT_B_C_RMS,

 LLM_KV_WKV_HEAD_SIZE,

@@ -293,6 +295,7 @@ enum llm_tensor {
 LLM_TENSOR_SSM_DT,
 LLM_TENSOR_SSM_A,
 LLM_TENSOR_SSM_D,
+LLM_TENSOR_SSM_NORM,
 LLM_TENSOR_SSM_OUT,
 LLM_TENSOR_TIME_MIX_W0,
 LLM_TENSOR_TIME_MIX_W1,
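Illustrative note (not part of the diff): the GGUF keys in LLM_KV_NAMES use a "%s" placeholder for the architecture name, so "%s.ssm.group_count" expands to "mamba2.ssm.group_count", and the per-block tensor names in LLM_TENSOR_NAMES use "%d" for the layer index. A minimal sketch of that expansion, assuming only the standard C library; the helper name format_blk_name is made up for this sketch:

    // Illustrative: expand a per-layer tensor-name template such as "blk.%d.ssm_norm".
    #include <cstdio>
    #include <string>

    static std::string format_blk_name(const char * tmpl, int il) {
        char buf[128];
        std::snprintf(buf, sizeof(buf), tmpl, il); // "blk.%d.ssm_norm" -> "blk.3.ssm_norm"
        return buf;
    }

    int main() {
        std::printf("%s\n", format_blk_name("blk.%d.ssm_norm", 3).c_str());
        return 0;
    }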
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(

 // note: tracking the other way around is not necessary for now
 //seq_cpl[s0][s1] = true;
+
+has_cpl = true;
 }
 }
 }

@@ -404,6 +406,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
 return n_outputs;
 }

+uint32_t llama_batch_allocr::get_n_used() const {
+return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
 return out_ids;
 }

@@ -419,6 +425,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
 out_ids.clear();

+n_used = 0;
+
 used.clear();
 used.resize(get_n_tokens(), false);

@@ -443,6 +451,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 idxs.push_back(cur_idx);

 used[cur_idx] = true;
+++n_used;

 ++cur_idx;

@@ -458,9 +467,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 return ubatch_add(idxs, idxs.size(), false);
 }

-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+if (sequential && has_cpl) {
+LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+return {};
+}
+
 std::vector<seq_set_t> cur_seq_set;

+llama_seq_id last_seq_id = -1;
+
 // determine the non-overlapping sequence sets participating in this ubatch
 for (int32_t i = 0; i < batch.n_tokens; ++i) {
 if (used[i]) {

@@ -477,9 +494,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 }
 }

+// accept only increasing sequence ids
+if (sequential) {
+add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+}
+
 if (add) {
 cur_seq_set.push_back(seq_set[i]);

+last_seq_id = batch.seq_id[i][0];
+
 if (cur_seq_set.size() > n_ubatch) {
 break;
 }

@@ -528,6 +552,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 idxs_per_seq[s].push_back(idx);

 used[idx] = true;
+++n_used;

 ++cur_idx[s];
 }

@@ -569,6 +594,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
 idxs.push_back(cur_idx);

 used[cur_idx] = true;
+++n_used;

 if (idxs.size() >= n_ubatch) {
 break;

@@ -54,6 +54,7 @@ public:

 uint32_t get_n_tokens() const;
 uint32_t get_n_outputs() const;
+uint32_t get_n_used() const;

 // the array of output indices in the order they were encountered during the ubatch splitting
 std::vector<int32_t> & get_out_ids();

@@ -69,7 +70,8 @@ public:
 llama_ubatch split_simple(uint32_t n_ubatch);

 // make ubatches of equal-length sequences sets
-llama_ubatch split_equal(uint32_t n_ubatch);
+// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);

 // sequence-set-wise split - each ubatch contains a single sequence-set
 llama_ubatch split_seq(uint32_t n_ubatch);

@@ -112,6 +114,9 @@ private:
 using pos_set_t = std::set<llama_pos>;
 using seq_cpl_t = std::vector<bool>;

+// helper flag to quickly determine if there are any coupled sequences in the batch
+bool has_cpl;
+
 std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
 std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

@@ -125,6 +130,8 @@ private:
 // batch indices of the output
 std::vector<int32_t> out_ids;

+uint32_t n_used;
+
 // used[i] indicates if token i has already been used in a previous ubatch
 std::vector<bool> used;
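Aside (not part of the diff): the new sequential mode of split_equal() only accepts tokens whose first sequence id increases by exactly one relative to the previously accepted sequence, so a gap ends the ubatch. A rough standalone sketch of that acceptance rule; the function name pick_sequential_run and the flat list of candidate sequence ids are simplifications, not the llama.cpp implementation:

    // Illustrative: keep the longest prefix of sequence ids that increase by exactly 1.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static std::vector<int32_t> pick_sequential_run(const std::vector<int32_t> & candidate_seq_ids) {
        std::vector<int32_t> out;
        int32_t last_seq_id = -1;
        for (int32_t id : candidate_seq_ids) {
            if (!out.empty() && id != last_seq_id + 1) {
                break; // a gap ends the run; later sequences go into the next ubatch
            }
            out.push_back(id);
            last_seq_id = id;
        }
        return out;
    }

    int main() {
        for (int32_t id : pick_sequential_run({3, 4, 5, 7, 8})) {
            std::printf("seq %d\n", id); // prints 3, 4, 5
        }
        return 0;
    }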
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-if (self_kq_mask) {
-mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-}
+mctx->set_input_k_idxs(self_k_idxs, ubatch);
+mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-if (self_kq_mask) {
-mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-}
-
-if (self_kq_mask_swa) {
-mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-}
+mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+
+mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }

 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {

@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-if (self_kq_mask) {
-mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-}
+mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);

 const int64_t n_rs = mctx->get_recr()->get_n_rs();

@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 }
 }

-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+GGML_UNUSED(ubatch);
 GGML_ASSERT(one && ggml_nelements(one) == 1);
 float f_one = 1.0f;
 ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));

@@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {

 const auto n_kv = inp->mctx->get_attn()->get_n_kv();

-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp->self_kq_mask, "KQ_mask", -1);
+inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
 auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

 // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp_kq_mask, "KQ_mask", -1);
+inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->kq_mask);

 inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;

@@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

 const auto n_kv = mctx_cur->get_n_kv();

-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp->self_kq_mask, "KQ_mask", -1);
+inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn(

 // store to KV cache
 {
-ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+const auto & k_idxs = inp->get_k_idxs();
+const auto & v_idxs = inp->get_v_idxs();
+
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
 }

 const auto & kq_mask = inp->get_kq_mask();

@@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn(

 // optionally store to KV cache
 if (k_cur) {
-ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
 }

 if (v_cur) {
-ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
 }

 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();

@@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

 const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

-inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->cross_kq_mask);

 inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;

@@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn(

 // store to KV cache
 {
-ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+const auto & k_idxs = inp->get_k_idxs();
+const auto & v_idxs = inp->get_v_idxs();
+
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
 }

 const auto & kq_mask = inp->get_kq_mask();

@@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 {
 const auto n_kv = mctx_cur->get_base()->get_n_kv();

-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp->self_kq_mask, "KQ_mask", -1);
+inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->self_kq_mask);

 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

 const auto n_kv = mctx_cur->get_swa()->get_n_kv();

-inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
 ggml_set_input(inp->self_kq_mask_swa);

 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

@@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs(
 uint32_t kv_head,
 uint32_t kv_size,
 int32_t rs_zero,
-bool avoid_copies) const {
+const llm_graph_get_rows_fn & get_state_rows) const {

 ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

@@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs(
 ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
 ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

-ggml_tensor * output_states;
-
-if (!avoid_copies) {
-// copy states
-// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-// {state_size, kv_size} -> {state_size, n_seqs}
-output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-ggml_build_forward_expand(gf, output_states);
-} else {
-// FIXME: make the gathering operation happen before the copy below
-// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-output_states = states;
-}
+// copy states
+// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+// {state_size, kv_size} -> {state_size, n_seqs}
+ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ggml_build_forward_expand(gf, output_states);

 // copy extra states which won't be changed further (between n_seqs and n_kv)
 ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));

@@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs(
 ggml_tensor * s,
 int32_t state_size,
 int32_t n_seqs,
-bool avoid_copies) const {
-const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+const llm_graph_get_rows_fn & get_state_rows) const {
+const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);

-return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }

 ggml_tensor * llm_graph_context::build_rs(

@@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs(
 ggml_tensor * s,
 int32_t state_size,
 int32_t n_seqs,
-bool avoid_copies) const {
-const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+const llm_graph_get_rows_fn & get_state_rows) const {
+const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

-return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }

 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
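Aside (not part of the diff): replacing the avoid_copies flag with a get_state_rows callback turns the row-gathering step into a customization point, with ggml_get_rows as the default. The same shape of API in plain, dependency-free C++; all names below (RowGatherFn, gather_rows, build_states) are illustrative:

    // Illustrative: pass the gather operation as a defaulted std::function argument.
    #include <cstdio>
    #include <functional>
    #include <vector>

    using Rows = std::vector<std::vector<float>>;
    using RowGatherFn = std::function<Rows(const Rows & states, const std::vector<int> & ids)>;

    // default gather: plain row lookup (conceptually what ggml_get_rows does)
    static Rows gather_rows(const Rows & states, const std::vector<int> & ids) {
        Rows out;
        for (int id : ids) {
            out.push_back(states[id]);
        }
        return out;
    }

    // callers may substitute a different gather, e.g. one that fuses extra work into it
    static Rows build_states(const Rows & states, const std::vector<int> & ids,
                             const RowGatherFn & get_state_rows = gather_rows) {
        return get_state_rows(states, ids);
    }

    int main() {
        const Rows states = {{1.0f}, {2.0f}, {3.0f}};
        const Rows out = build_states(states, {2, 0}); // uses the default gather
        std::printf("%g %g\n", out[0][0], out[1][0]);  // prints 3 1
        return 0;
    }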
@@ -228,8 +228,8 @@ public:

 ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }

-ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
-ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]

 const llama_hparams & hparams;
 const llama_cparams & cparams;

@@ -249,10 +249,16 @@ public:

 void set_input(const llama_ubatch * ubatch) override;

+ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]

 const llama_hparams & hparams;
 const llama_cparams & cparams;

@@ -274,13 +280,23 @@ public:

 void set_input(const llama_ubatch * ubatch) override;

+ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]

 const llama_hparams & hparams;
 const llama_cparams & cparams;

@@ -297,8 +313,8 @@ public:

 ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

-ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
-ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

 const llama_cross * cross = nullptr;
 };

@@ -319,10 +335,16 @@ public:

 ggml_tensor * s_copy; // I32 [kv_size]

+ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]

 const llama_hparams & hparams;
 const llama_cparams & cparams;

@@ -336,7 +358,7 @@ public:
 llm_graph_input_one() {}
 virtual ~llm_graph_input_one() = default;

-void set_input(const llama_ubatch *) override;
+void set_input(const llama_ubatch * ubatch) override;

 ggml_tensor * one = nullptr; // F32
 };

@@ -424,6 +446,9 @@ struct llm_graph_params {
 const llm_graph_cb & cb;
 };

+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
 struct llm_graph_context {
 const llm_arch arch;

@@ -663,7 +688,7 @@ struct llm_graph_context {
 uint32_t kv_head,
 uint32_t kv_size,
 int32_t rs_zero,
-bool avoid_copies = false) const;
+const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

 llm_graph_input_rs * build_rs_inp() const;

@@ -673,7 +698,7 @@ struct llm_graph_context {
 ggml_tensor * s,
 int32_t state_size,
 int32_t n_seqs,
-bool avoid_copies = false) const;
+const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

 ggml_tensor * build_rs(
 llm_graph_input_mem_hybrid * inp,

@@ -681,7 +706,7 @@ struct llm_graph_context {
 ggml_tensor * s,
 int32_t state_size,
 int32_t n_seqs,
-bool avoid_copies = false) const;
+const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

 ggml_tensor * build_rwkv_token_shift_load(
 llm_graph_input_rs * inp,
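Note for context (not part of the diff): the mask comments change from [n_kv, n_batch] to [n_kv, n_batch, 1, 1]; the data is the same, only the two trailing broadcast dimensions become explicit. A small sketch assuming the standard ggml C API (the sizes 8 and 4 are arbitrary):

    // Illustrative: a 2-D tensor vs. the equivalent 4-D tensor with trailing unit dims.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * mask_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);       // [8, 4]
        struct ggml_tensor * mask_4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 1, 1); // [8, 4, 1, 1]

        // same number of elements; the 4-D form just spells out the broadcast dims
        std::printf("%lld vs %lld elements\n",
                    (long long) ggml_nelements(mask_2d), (long long) ggml_nelements(mask_4d));

        ggml_free(ctx);
        return 0;
    }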
@@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {

 // TODO: maybe support other convolution strides than 1
 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+// Corresponds to Mamba's conv_states size
+return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }

 uint32_t llama_hparams::n_embd_s() const {

@@ -114,6 +114,7 @@ struct llama_hparams {
 uint32_t ssm_d_inner = 0;
 uint32_t ssm_d_state = 0;
 uint32_t ssm_dt_rank = 0;
+uint32_t ssm_n_group = 0;

 // for hybrid state space models
 std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
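Worked example (not part of the diff): the new n_embd_r() widens the stored conv state from d_inner to d_inner + 2*n_group*d_state per kept column, with d_conv - 1 columns kept. The hyperparameter values below are made up for illustration, not taken from any particular model:

    // Illustrative arithmetic for the conv-state size formula above.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t ssm_d_conv  = 4;    // convolution kernel size
        const uint32_t ssm_d_inner = 8192; // inner (expanded) dimension
        const uint32_t ssm_n_group = 8;    // number of groups
        const uint32_t ssm_d_state = 128;  // state size

        const uint32_t old_size = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
        const uint32_t new_size = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);

        std::printf("old conv state: %u floats, new conv state: %u floats\n", old_size, new_size); // 24576 vs 30720
        return 0;
    }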
@@ -117,20 +117,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }

-auto heads_base = kv_base->prepare(ubatches);
-if (heads_base.empty()) {
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
 break;
 }

-auto heads_swa = kv_swa->prepare(ubatches);
-if (heads_swa.empty()) {
+auto sinfos_base = kv_base->prepare(ubatches);
+if (sinfos_base.empty()) {
 break;
 }

-assert(heads_base.size() == heads_swa.size());
+auto sinfos_swa = kv_swa->prepare(ubatches);
+if (sinfos_swa.empty()) {
+break;
+}
+
+assert(sinfos_base.size() == sinfos_swa.size());

 return std::make_unique<llama_kv_cache_unified_iswa_context>(
-this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
 } while (false);

 // if it fails, try equal split

@@ -139,7 +144,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

 std::vector<llama_ubatch> ubatches;
 while (true) {
-auto ubatch = balloc.split_equal(n_ubatch);
+auto ubatch = balloc.split_equal(n_ubatch, false);

 if (ubatch.n_tokens == 0) {
 break;

@@ -148,20 +153,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }

-auto heads_base = kv_base->prepare(ubatches);
-if (heads_base.empty()) {
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
 break;
 }

-auto heads_swa = kv_swa->prepare(ubatches);
-if (heads_swa.empty()) {
+auto sinfos_base = kv_base->prepare(ubatches);
+if (sinfos_base.empty()) {
 break;
 }

-assert(heads_base.size() == heads_swa.size());
+auto sinfos_swa = kv_swa->prepare(ubatches);
+if (sinfos_swa.empty()) {
+break;
+}
+
+assert(sinfos_base.size() == sinfos_swa.size());

 return std::make_unique<llama_kv_cache_unified_iswa_context>(
-this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
 } while (false);

 // TODO: if we fail again, we should attempt different splitting strategies

@@ -224,13 +234,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(

 llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
-std::vector<uint32_t> heads_base,
-std::vector<uint32_t> heads_swa,
+slot_info_vec_t sinfos_base,
+slot_info_vec_t sinfos_swa,
 std::vector<llama_ubatch> ubatches) :
 ubatches(std::move(ubatches)),
 // note: here we copy the ubatches. not sure if this is ideal
-ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
-ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
 status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }

@@ -74,6 +74,8 @@ private:

 class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
+using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
 // used for errors
 llama_kv_cache_unified_iswa_context(llama_memory_status status);

@@ -90,8 +92,8 @@ public:
 // used to create a batch processing context from a batch
 llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
-std::vector<uint32_t> heads_base,
-std::vector<uint32_t> heads_swa,
+slot_info_vec_t sinfos_base,
+slot_info_vec_t sinfos_swa,
 std::vector<llama_ubatch> ubatches);

 virtual ~llama_kv_cache_unified_iswa_context();
@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(

 const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
 debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+if (!supports_set_rows) {
+LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+}
 }

 void llama_kv_cache_unified::clear(bool data) {

@@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }

-auto heads = prepare(ubatches);
-if (heads.empty()) {
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
+break;
+}
+
+auto sinfos = prepare(ubatches);
+if (sinfos.empty()) {
 break;
 }

 return std::make_unique<llama_kv_cache_unified_context>(
-this, std::move(heads), std::move(ubatches));
+this, std::move(sinfos), std::move(ubatches));
 } while (false);

 return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);

@@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }

-llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-llama_kv_cache_unified::ubatch_heads res;
+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+llama_kv_cache_unified::slot_info_vec_t res;

 struct state {
 uint32_t head_old; // old position of the head, before placing the ubatch
-uint32_t head_new; // new position of the head, after placing the ubatch
+
+slot_info sinfo; // slot info for the ubatch

 llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
 };

@@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
 bool success = true;

 for (const auto & ubatch : ubatches) {
+// non-continuous slots require support for ggml_set_rows()
+const bool cont = supports_set_rows ? false : true;
+
 // only find a suitable slot for the ubatch. don't modify the cells yet
-const int32_t head_new = find_slot(ubatch);
-if (head_new < 0) {
+const auto sinfo_new = find_slot(ubatch, cont);
+if (sinfo_new.empty()) {
 success = false;
 break;
 }

 // remeber the position that we found
-res.push_back(head_new);
+res.push_back(sinfo_new);

 // store the old state of the cells in the recovery stack
-states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
+states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});

 // now emplace the ubatch
-apply_ubatch(head_new, ubatch);
+apply_ubatch(sinfo_new, ubatch);
 }

 // iterate backwards and restore the cells to their original state
 for (auto it = states.rbegin(); it != states.rend(); ++it) {
-cells.set(it->head_new, it->cells);
+cells.set(it->sinfo.idxs, it->cells);
 head = it->head_old;
 }
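Aside (not part of the diff): prepare() above is a dry run. It places each ubatch, pushes the overwritten cell state onto a recovery stack, and then walks the stack in reverse to restore the cache. A condensed sketch of that pattern; the int-based cells and the Saved struct are stand-ins, not the llama.cpp types:

    // Illustrative: apply changes, remember the old state, then roll back in reverse.
    #include <cstdio>
    #include <vector>

    struct Saved {
        size_t idx;      // which cell was touched
        int    old_val;  // its previous contents
    };

    int main() {
        std::vector<int> cells = {1, 2, 3, 4};
        std::vector<Saved> recovery;

        const std::vector<size_t> touched = {1, 3};
        for (size_t idx : touched) {
            recovery.push_back({idx, cells[idx]}); // remember before modifying
            cells[idx] = 99;                       // "apply_ubatch"
        }

        for (auto it = recovery.rbegin(); it != recovery.rend(); ++it) {
            cells[it->idx] = it->old_val;          // restore the original state
        }

        for (int v : cells) {
            std::printf("%d ", v);                 // prints: 1 2 3 4
        }
        std::printf("\n");
        return 0;
    }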
@@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 return updated;
 }

-int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
 const uint32_t n_tokens = ubatch.n_tokens;

 uint32_t head_cur = this->head;

@@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

 if (n_tokens > cells.size()) {
 LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-return -1;
+return { };
 }

 if (debug > 0) {

@@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

 uint32_t n_tested = 0;

+// for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+// for non-continuous slots, we test the tokens one by one
+const uint32_t n_test = cont ? n_tokens : 1;
+
+slot_info res;
+
+auto & idxs = res.idxs;
+
+idxs.reserve(n_tokens);
+
 while (true) {
-if (head_cur + n_tokens > cells.size()) {
+if (head_cur + n_test > cells.size()) {
 n_tested += cells.size() - head_cur;
 head_cur = 0;
 continue;
 }

-bool found = true;
-for (uint32_t i = 0; i < n_tokens; i++) {
+for (uint32_t i = 0; i < n_test; i++) {
+const auto idx = head_cur;
+
 //const llama_pos pos = ubatch.pos[i];
 //const llama_seq_id seq_id = ubatch.seq_id[i][0];

@@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 // - (disabled) mask causally, if the sequence is the same as the one we are inserting
 // - mask SWA, using current max pos for that sequence in the cache
 // always insert in the cell with minimum pos
-bool can_use = cells.is_empty(head_cur + i);
+bool can_use = cells.is_empty(idx);

-if (!can_use && cells.seq_count(head_cur + i) == 1) {
-const llama_pos pos_cell = cells.pos_get(head_cur + i);
+if (!can_use && cells.seq_count(idx) == 1) {
+const llama_pos pos_cell = cells.pos_get(idx);

 // (disabled) causal mask
 // note: it's better to purge any "future" tokens beforehand
-//if (cells.seq_has(head_cur + i, seq_id)) {
+//if (cells.seq_has(idx, seq_id)) {
 // can_use = pos_cell >= pos;
 //}

 if (!can_use) {
-const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+const llama_seq_id seq_id_cell = cells.seq_get(idx);

 // SWA mask
 if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {

@@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 }

-if (!can_use) {
-found = false;
-head_cur += i + 1;
-n_tested += i + 1;
+head_cur++;
+n_tested++;
+
+if (can_use) {
+idxs.push_back(idx);
+} else {
 break;
 }
 }

-if (found) {
+if (idxs.size() == n_tokens) {
 break;
 }

+if (cont) {
+idxs.clear();
+}
+
 if (n_tested >= cells.size()) {
 //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-return -1;
+return { };
 }
 }

-return head_cur;
+// we didn't find a suitable slot - return empty result
+if (idxs.size() < n_tokens) {
+res.clear();
+}
+
+return res;
 }

-void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
 // keep track of the max sequence position that we would overwrite with this ubatch
 // for non-SWA cache, this would be always empty
 llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
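Aside (not part of the diff): with the new cont flag, find_slot() either requires one contiguous run of n_tokens usable cells or collects usable cells one at a time anywhere in the cache, returning an empty result when it cannot. A toy version of the non-contiguous case; is_free and the vector-of-bool cache are stand-ins for the real cell bookkeeping:

    // Illustrative: collect the first n_tokens free cells, not necessarily contiguous.
    #include <cstdio>
    #include <vector>

    static std::vector<size_t> find_free_cells(const std::vector<bool> & is_free, size_t n_tokens) {
        std::vector<size_t> idxs;
        idxs.reserve(n_tokens);
        for (size_t i = 0; i < is_free.size() && idxs.size() < n_tokens; ++i) {
            if (is_free[i]) {
                idxs.push_back(i);
            }
        }
        if (idxs.size() < n_tokens) {
            idxs.clear(); // no suitable slot: return an empty result, like find_slot()
        }
        return idxs;
    }

    int main() {
        const std::vector<bool> is_free = {false, true, false, true, true};
        for (size_t idx : find_free_cells(is_free, 3)) {
            std::printf("cell %zu\n", idx); // prints cells 1, 3, 4
        }
        return 0;
    }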
@@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 seq_pos_max_rm[s] = -1;
 }

-for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-if (!cells.is_empty(head_cur + i)) {
-assert(cells.seq_count(head_cur + i) == 1);
+assert(ubatch.n_tokens == sinfo.idxs.size());

-const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-const llama_pos pos = cells.pos_get(head_cur + i);
+for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+const auto idx = sinfo.idxs.at(i);
+
+if (!cells.is_empty(idx)) {
+assert(cells.seq_count(idx) == 1);
+
+const llama_seq_id seq_id = cells.seq_get(idx);
+const llama_pos pos = cells.pos_get(idx);

 seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

-cells.rm(head_cur + i);
+cells.rm(idx);
 }

-cells.pos_set(head_cur + i, ubatch.pos[i]);
+cells.pos_set(idx, ubatch.pos[i]);

 for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
+cells.seq_add(idx, ubatch.seq_id[i][s]);
 }
 }

@@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 }

 // move the head at the end of the slot
-head = head_cur + ubatch.n_tokens;
+head = sinfo.idxs.back() + 1;
 }

 bool llama_kv_cache_unified::get_can_shift() const {

@@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
 0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
 const int32_t ikv = map_layer_ids.at(il);

 auto * k = layers[ikv].k;

+const int64_t n_embd_k_gqa = k->ne[0];
 const int64_t n_tokens = k_cur->ne[2];

+k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+if (k_idxs && supports_set_rows) {
+return ggml_set_rows(ctx, k, k_cur, k_idxs);
+}
+
+// TODO: fallback to old ggml_cpy() method for backwards compatibility
+// will be removed when ggml_set_rows() is adopted by all backends
+
 ggml_tensor * k_view = ggml_view_1d(ctx, k,
-n_tokens*hparams.n_embd_k_gqa(il),
-ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+n_tokens*n_embd_k_gqa,
+ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

 return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
 const int32_t ikv = map_layer_ids.at(il);

 auto * v = layers[ikv].v;

+const int64_t n_embd_v_gqa = v->ne[0];
 const int64_t n_tokens = v_cur->ne[2];

-v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
+v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

+if (v_idxs && supports_set_rows) {
+if (!v_trans) {
+return ggml_set_rows(ctx, v, v_cur, v_idxs);
+}
+
+// the row becomes a single element
+ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+// note: the V cache is transposed when not using flash attention
+v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+// note: we can be more explicit here at the cost of extra cont
|
||||||
|
// however, above we take advantage that a row of single element is always continuous regardless of the row stride
|
||||||
|
//v_cur = ggml_transpose(ctx, v_cur);
|
||||||
|
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
|
||||||
|
|
||||||
|
// we broadcast the KV indices n_embd_v_gqa times
|
||||||
|
// v [1, n_kv, n_embd_v_gqa]
|
||||||
|
// v_cur [1, n_tokens, n_embd_v_gqa]
|
||||||
|
// v_idxs [n_tokens, 1, 1]
|
||||||
|
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
|
||||||
|
}
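The transposed branch above scatters one element per feature plane. As a plain C++ illustration of that layout trick (not ggml code; the feature-major layout and the function name are assumptions of this sketch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the transposed V store: the cache is laid out feature-major, so
// element (cell, feature e) lives at e*n_kv + cell. Writing token t into cell
// idxs[t] therefore scatters one element per feature plane, which is what the
// {1, n_kv, n_embd} "rows of a single element" view expresses.
void store_v_transposed(std::vector<float> & v_cache,          // size n_embd * n_kv
                        const std::vector<float> & v_cur,      // size n_embd * n_tokens, token-major
                        const std::vector<uint32_t> & idxs,    // size n_tokens
                        size_t n_kv, size_t n_embd) {
    const size_t n_tokens = idxs.size();
    for (size_t t = 0; t < n_tokens; ++t) {
        for (size_t e = 0; e < n_embd; ++e) {
            v_cache[e*n_kv + idxs[t]] = v_cur[t*n_embd + e];
        }
    }
}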
|
||||||
|
|
||||||
|
// TODO: fallback to old ggml_cpy() method for backwards compatibility
|
||||||
|
// will be removed when ggml_set_rows() is adopted by all backends
|
||||||
|
|
||||||
ggml_tensor * v_view = nullptr;
|
ggml_tensor * v_view = nullptr;
|
||||||
|
|
||||||
if (!v_trans) {
|
if (!v_trans) {
|
||||||
v_view = ggml_view_1d(ctx, v,
|
v_view = ggml_view_1d(ctx, v,
|
||||||
n_tokens*hparams.n_embd_v_gqa(il),
|
n_tokens*n_embd_v_gqa,
|
||||||
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
|
ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
|
||||||
} else {
|
} else {
|
||||||
// note: the V cache is transposed when not using flash attention
|
|
||||||
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
|
|
||||||
(v->ne[1])*ggml_element_size(v),
|
|
||||||
(head_cur)*ggml_element_size(v));
|
|
||||||
|
|
||||||
v_cur = ggml_transpose(ctx, v_cur);
|
v_cur = ggml_transpose(ctx, v_cur);
|
||||||
|
|
||||||
|
v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
|
||||||
|
(v->ne[1] )*ggml_element_size(v),
|
||||||
|
(sinfo.head())*ggml_element_size(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ggml_cpy(ctx, v_cur, v_view);
|
return ggml_cpy(ctx, v_cur, v_view);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs.at(i);
+    }
+}

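These index tensors are consumed as row indices by the set-rows path; in the non-transposed case the effect is a plain row scatter, dst[idxs[i]] = src[i]. A small illustrative sketch of that semantics (not the backend implementation of ggml_set_rows):

#include <cstdint>
#include <vector>

// Sketch of what the graph does with the filled k_idxs/v_idxs buffers in the
// non-transposed case: copy source row i into destination row idxs[i].
void set_rows_sketch(std::vector<std::vector<float>> & dst,        // n_kv rows of n_embd
                     const std::vector<std::vector<float>> & src,  // n_tokens rows of n_embd
                     const std::vector<int64_t> & idxs) {          // n_tokens indices
    for (size_t i = 0; i < idxs.size(); ++i) {
        dst[(size_t) idxs[i]] = src[i];
    }
}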
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||||
const uint32_t n_tokens = ubatch->n_tokens;
|
const uint32_t n_tokens = ubatch->n_tokens;
|
||||||
|
|
||||||
|
@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||||
ubatch.seq_id[i] = &dest_seq_id;
|
ubatch.seq_id[i] = &dest_seq_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto head_cur = find_slot(ubatch);
|
const auto sinfo = find_slot(ubatch, true);
|
||||||
if (head_cur < 0) {
|
if (sinfo.empty()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
apply_ubatch(head_cur, ubatch);
|
apply_ubatch(sinfo, ubatch);
|
||||||
|
|
||||||
|
const auto head_cur = sinfo.head();
|
||||||
|
|
||||||
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
||||||
head = head_cur;
|
head = head_cur;
|
||||||
|
@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
|
||||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||||
n_kv = kv->get_size();
|
n_kv = kv->get_size();
|
||||||
head = 0;
|
|
||||||
|
// create a dummy slot info - the actual data is irrelevant. we just need to build the graph
|
||||||
|
sinfos.resize(1);
|
||||||
|
sinfos[0].idxs.resize(1);
|
||||||
|
sinfos[0].idxs[0] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
|
@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
|
|
||||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
llama_kv_cache_unified::ubatch_heads heads,
|
llama_kv_cache_unified::slot_info_vec_t sinfos,
|
||||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
||||||
|
@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
||||||
bool llama_kv_cache_unified_context::next() {
|
bool llama_kv_cache_unified_context::next() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
if (++i_next >= ubatches.size()) {
|
if (++i_cur >= ubatches.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
|
||||||
|
|
||||||
n_kv = kv->get_n_kv();
|
n_kv = kv->get_n_kv();
|
||||||
head = heads[i_next];
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
||||||
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_cur];
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
||||||
|
@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
|
||||||
return kv->get_v(ctx, il, n_kv);
|
return kv->get_v(ctx, il, n_kv);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||||
return kv->cpy_k(ctx, k_cur, il, head);
|
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
|
||||||
return kv->cpy_v(ctx, v_cur, il, head);
|
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||||
|
return kv->build_input_k_idxs(ctx, ubatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||||
|
return kv->build_input_v_idxs(ctx, ubatch);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||||
kv->set_input_k_shift(dst);
|
kv->set_input_k_shift(dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||||
|
kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||||
|
kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||||
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,8 +24,6 @@ public:
|
||||||
// this callback is used to filter out layers that should not be included in the cache
|
// this callback is used to filter out layers that should not be included in the cache
|
||||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||||
|
|
||||||
using ubatch_heads = std::vector<uint32_t>;
|
|
||||||
|
|
||||||
struct defrag_info {
|
struct defrag_info {
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return ids.empty();
|
return ids.empty();
|
||||||
|
@ -37,6 +35,32 @@ public:
|
||||||
std::vector<uint32_t> ids;
|
std::vector<uint32_t> ids;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        idx_vec_t idxs;
+
+        uint32_t head() const {
+            return idxs.at(0);
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+
+        // TODO: implement
+        //std::vector<idx_vec_t> seq_idxs;
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
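A hypothetical usage sketch of the slot_info declared above, re-declared locally so the snippet stands alone (the index values are made up): a contiguous slot has consecutive idxs, so code that only knows about a head offset can still call head(); a scattered slot is only meaningful through the per-token idxs.

#include <cassert>
#include <cstdint>
#include <vector>

struct slot_info_example {
    std::vector<uint32_t> idxs;
    uint32_t head()  const { return idxs.at(0); }
    bool     empty() const { return idxs.empty(); }
};

int main() {
    const slot_info_example contiguous = { { 8, 9, 10, 11 } }; // token i -> cell 8 + i
    const slot_info_example scattered  = { { 3, 5, 6, 12 } };  // token i -> cells[idxs[i]]

    assert(!contiguous.empty() && contiguous.head() == 8);
    assert(!scattered.empty()  && scattered.head()  == 3);

    return 0;
}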
llama_kv_cache_unified(
|
llama_kv_cache_unified(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
layer_filter_cb && filter,
|
layer_filter_cb && filter,
|
||||||
|
@ -102,30 +126,37 @@ public:
|
||||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||||
|
|
||||||
// store k_cur and v_cur in the cache based on the provided head location
|
// store k_cur and v_cur in the cache based on the provided head location
|
||||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
|
||||||
|
|
||||||
//
|
//
|
||||||
// preparation API
|
// preparation API
|
||||||
//
|
//
|
||||||
|
|
||||||
// find places for the provided ubatches in the cache, returns the head locations
|
// find places for the provided ubatches in the cache, returns the slot infos
|
||||||
// return empty vector on failure
|
// return empty vector on failure
|
||||||
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||||
|
|
||||||
// return the cell position where we can insert the ubatch
|
// find a slot of kv cells that can hold the ubatch
|
||||||
// return -1 on failure to find a contiguous slot of kv cells
|
// if cont == true, then the slot must be continuous
|
||||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
// return empty slot_info on failure
|
||||||
|
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
|
||||||
|
|
||||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
||||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
||||||
|
|
||||||
//
|
//
|
||||||
// set_input API
|
// input API
|
||||||
//
|
//
|
||||||
|
|
||||||
|
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||||
|
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||||
|
|
||||||
|
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||||
|
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||||
|
|
||||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||||
void set_input_k_shift (ggml_tensor * dst) const;
|
void set_input_k_shift (ggml_tensor * dst) const;
|
||||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
@ -157,8 +188,13 @@ private:
|
||||||
// SWA
|
// SWA
|
||||||
const uint32_t n_swa = 0;
|
const uint32_t n_swa = 0;
|
||||||
|
|
||||||
|
// env: LLAMA_KV_CACHE_DEBUG
|
||||||
int debug = 0;
|
int debug = 0;
|
||||||
|
|
||||||
|
// env: LLAMA_SET_ROWS (temporary)
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
|
||||||
|
int supports_set_rows = false;
|
||||||
|
|
||||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||||
|
|
||||||
std::vector<ggml_context_ptr> ctxs;
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
|
@ -211,8 +247,8 @@ private:
|
||||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
// some shorthands
|
// some shorthands
|
||||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||||
|
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_kv_cache_unified_context(llama_memory_status status);
|
llama_kv_cache_unified_context(llama_memory_status status);
|
||||||
|
@ -231,7 +267,7 @@ public:
|
||||||
// used to create a batch procesing context from a batch
|
// used to create a batch procesing context from a batch
|
||||||
llama_kv_cache_unified_context(
|
llama_kv_cache_unified_context(
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
ubatch_heads heads,
|
slot_info_vec_t sinfos,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_kv_cache_unified_context();
|
virtual ~llama_kv_cache_unified_context();
|
||||||
|
@ -257,11 +293,16 @@ public:
|
||||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||||
|
|
||||||
// store k_cur and v_cur in the cache based on the provided head location
|
// store k_cur and v_cur in the cache based on the provided head location
|
||||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
|
||||||
|
|
||||||
void set_input_k_shift(ggml_tensor * dst) const;
|
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||||
|
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||||
|
|
||||||
|
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
|
||||||
|
void set_input_k_shift (ggml_tensor * dst) const;
|
||||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
|
||||||
|
@ -283,10 +324,10 @@ private:
|
||||||
// batch processing context
|
// batch processing context
|
||||||
//
|
//
|
||||||
|
|
||||||
// the index of the next ubatch to process
|
// the index of the cur ubatch to process
|
||||||
size_t i_next = 0;
|
size_t i_cur = 0;
|
||||||
|
|
||||||
ubatch_heads heads;
|
slot_info_vec_t sinfos;
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
@ -297,7 +338,4 @@ private:
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// as the cache gets filled, the benefit from this heuristic disappears
|
// as the cache gets filled, the benefit from this heuristic disappears
|
||||||
int32_t n_kv;
|
int32_t n_kv;
|
||||||
|
|
||||||
// the beginning of the current slot in which the ubatch will be inserted
|
|
||||||
int32_t head;
|
|
||||||
};
|
};
|
||||||
|
|
|
@ -105,10 +105,30 @@ public:
|
||||||
res.resize(n);
|
res.resize(n);
|
||||||
|
|
||||||
for (uint32_t j = 0; j < n; ++j) {
|
for (uint32_t j = 0; j < n; ++j) {
|
||||||
res.pos[j] = pos[i + j];
|
const auto idx = i + j;
|
||||||
res.seq[j] = seq[i + j];
|
|
||||||
|
|
||||||
assert(shift[i + j] == 0);
|
res.pos[j] = pos[idx];
|
||||||
|
res.seq[j] = seq[idx];
|
||||||
|
|
||||||
|
assert(shift[idx] == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||||
|
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
|
||||||
|
llama_kv_cells_unified res;
|
||||||
|
|
||||||
|
res.resize(idxs.size());
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < idxs.size(); ++j) {
|
||||||
|
const auto idx = idxs[j];
|
||||||
|
|
||||||
|
res.pos[j] = pos[idx];
|
||||||
|
res.seq[j] = seq[idx];
|
||||||
|
|
||||||
|
assert(shift[idx] == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
|
@ -119,26 +139,58 @@ public:
|
||||||
assert(i + other.pos.size() <= pos.size());
|
assert(i + other.pos.size() <= pos.size());
|
||||||
|
|
||||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
const auto idx = i + j;
|
||||||
|
|
||||||
|
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||||
used.insert(i + j);
|
used.insert(i + j);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||||
used.erase(i + j);
|
used.erase(i + j);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pos[i + j] != -1) {
|
if (pos[idx] != -1) {
|
||||||
seq_pos_rm(i + j);
|
seq_pos_rm(i + j);
|
||||||
}
|
}
|
||||||
|
|
||||||
pos[i + j] = other.pos[j];
|
pos[idx] = other.pos[j];
|
||||||
seq[i + j] = other.seq[j];
|
seq[idx] = other.seq[j];
|
||||||
|
|
||||||
if (pos[i + j] != -1) {
|
if (pos[idx] != -1) {
|
||||||
seq_pos_add(i + j);
|
seq_pos_add(i + j);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(shift[i + j] == 0);
|
assert(shift[idx] == 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
|
||||||
|
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
|
||||||
|
assert(idxs.size() == other.pos.size());
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||||
|
const auto idx = idxs[j];
|
||||||
|
|
||||||
|
if (pos[idx] == -1 && other.pos[j] != -1) {
|
||||||
|
used.insert(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos[idx] != -1 && other.pos[j] == -1) {
|
||||||
|
used.erase(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos[idx] != -1) {
|
||||||
|
seq_pos_rm(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
pos[idx] = other.pos[j];
|
||||||
|
seq[idx] = other.seq[j];
|
||||||
|
|
||||||
|
if (pos[idx] != -1) {
|
||||||
|
seq_pos_add(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(shift[idx] == 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
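A minimal sketch of the backup/restore pattern that these index-based cp()/set() helpers make possible, modeling only the pos field (the standalone struct and values are illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

struct cells_sketch {
    std::vector<int32_t> pos; // -1 means the cell is empty

    // snapshot the state of an arbitrary set of cells
    cells_sketch cp(const std::vector<uint32_t> & idxs) const {
        cells_sketch res;
        res.pos.resize(idxs.size());
        for (size_t j = 0; j < idxs.size(); ++j) {
            res.pos[j] = pos[idxs[j]];
        }
        return res;
    }

    // put a snapshot back by the same indices
    void set(const std::vector<uint32_t> & idxs, const cells_sketch & other) {
        assert(idxs.size() == other.pos.size());
        for (size_t j = 0; j < idxs.size(); ++j) {
            pos[idxs[j]] = other.pos[j];
        }
    }
};

int main() {
    cells_sketch cells = { { -1, 7, -1, 9 } };
    const std::vector<uint32_t> idxs = { 1, 3 };

    const cells_sketch backup = cells.cp(idxs); // snapshot cells 1 and 3
    cells.pos[1] = 100;                         // ... temporarily overwrite them ...
    cells.pos[3] = 101;
    cells.set(idxs, backup);                    // restore the snapshot

    assert(cells.pos[1] == 7 && cells.pos[3] == 9);
    return 0;
}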
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||||
// if all tokens are output, split by sequence
|
// if all tokens are output, split by sequence
|
||||||
ubatch = balloc.split_seq(n_ubatch);
|
ubatch = balloc.split_seq(n_ubatch);
|
||||||
} else {
|
} else {
|
||||||
ubatch = balloc.split_equal(n_ubatch);
|
ubatch = balloc.split_equal(n_ubatch, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (ubatch.n_tokens == 0) {
|
||||||
|
@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
||||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// prepare the recurrent batches first
|
// prepare the recurrent batches first
|
||||||
if (!mem_recr->prepare(ubatches)) {
|
if (!mem_recr->prepare(ubatches)) {
|
||||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||||
|
@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||||
|
|
||||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
std::vector<uint32_t> heads_attn,
|
slot_info_vec_t sinfos_attn,
|
||||||
std::vector<llama_ubatch> ubatches) :
|
std::vector<llama_ubatch> ubatches) :
|
||||||
ubatches(std::move(ubatches)),
|
ubatches(std::move(ubatches)),
|
||||||
// note: here we copy the ubatches. not sure if this is ideal
|
// note: here we copy the ubatches. not sure if this is ideal
|
||||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
|
||||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,8 @@ private:
|
||||||
|
|
||||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||||
public:
|
public:
|
||||||
|
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||||
|
|
||||||
// init failure
|
// init failure
|
||||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||||
|
|
||||||
|
@ -107,7 +109,7 @@ public:
|
||||||
// init success
|
// init success
|
||||||
llama_memory_hybrid_context(
|
llama_memory_hybrid_context(
|
||||||
llama_memory_hybrid * mem,
|
llama_memory_hybrid * mem,
|
||||||
std::vector<uint32_t> heads_attn,
|
slot_info_vec_t sinfos_attn,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
~llama_memory_hybrid_context() = default;
|
~llama_memory_hybrid_context() = default;
|
||||||
|
|
|
@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
||||||
// if all tokens are output, split by sequence
|
// if all tokens are output, split by sequence
|
||||||
ubatch = balloc.split_seq(n_ubatch);
|
ubatch = balloc.split_seq(n_ubatch);
|
||||||
} else {
|
} else {
|
||||||
ubatch = balloc.split_equal(n_ubatch);
|
ubatch = balloc.split_equal(n_ubatch, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ubatch.n_tokens == 0) {
|
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||||
|
// failed to find a suitable split
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -213,23 +213,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
            } break;
        case GGML_OP_SSM_CONV:
            {
-               // FIXME
-               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+               const int64_t n_seq_tokens = 512;
+               const int64_t n_seqs       = 3;
+               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
            } break;
        case GGML_OP_SSM_SCAN:
            {
-               // FIXME
-               const int64_t d_state      = w->ne[0];
-               const int64_t d_inner      = w->ne[1];
+               // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2
+               const int64_t d_state      = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
+               const int64_t n_head       = w->ne[1];
+               const int64_t head_dim     = hparams.ssm_d_inner / n_head;
+               const int64_t n_group      = hparams.ssm_n_group ? hparams.ssm_n_group : 1;
                const int64_t n_seq_tokens = 512;
-               const int64_t n_seqs       = 1;
-               ggml_tensor * s   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-               ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * B   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               ggml_tensor * C   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
+               const int64_t n_seqs       = 3;
+               ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
+               ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
            } break;
case GGML_OP_RWKV_WKV6:
|
case GGML_OP_RWKV_WKV6:
|
||||||
{
|
{
|
||||||
|
@ -1086,6 +1090,38 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
default: type = LLM_TYPE_UNKNOWN;
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_MAMBA2:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||||
|
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||||
|
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||||
|
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||||
|
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||||
|
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 24:
|
||||||
|
switch (hparams.n_embd) {
|
||||||
|
case 768: type = LLM_TYPE_SMALL; break;
|
||||||
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
|
} break;
|
||||||
|
case 48:
|
||||||
|
switch (hparams.n_embd) {
|
||||||
|
case 1024: type = LLM_TYPE_MEDIUM; break;
|
||||||
|
case 1536: type = LLM_TYPE_LARGE; break;
|
||||||
|
case 2048: type = LLM_TYPE_XL; break;
|
||||||
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
|
} break;
|
||||||
|
case 64:
|
||||||
|
switch (hparams.n_embd) {
|
||||||
|
case 2560: type = LLM_TYPE_3B; break;
|
||||||
|
case 4096: type = LLM_TYPE_7B; break;
|
||||||
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
|
} break;
|
||||||
|
default: type = LLM_TYPE_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLM_ARCH_XVERSE:
|
case LLM_ARCH_XVERSE:
|
||||||
{
|
{
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
@ -3216,6 +3252,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
|
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
|
||||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
|
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
|
||||||
|
|
||||||
|
// out_proj
|
||||||
|
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case LLM_ARCH_MAMBA2:
|
||||||
|
{
|
||||||
|
const int64_t d_conv = hparams.ssm_d_conv;
|
||||||
|
const int64_t d_inner = hparams.ssm_d_inner;
|
||||||
|
const int64_t d_state = hparams.ssm_d_state;
|
||||||
|
const int64_t n_head = hparams.ssm_dt_rank;
|
||||||
|
const int64_t n_group = hparams.ssm_n_group;
|
||||||
|
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
|
||||||
|
|
||||||
|
// only an expansion factor of 2 is supported for now
|
||||||
|
GGML_ASSERT(2 * n_embd == d_inner);
|
||||||
|
|
||||||
|
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||||
|
|
||||||
|
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
// if output is NULL, init from the input tok embed, duplicated to allow offloading
|
||||||
|
if (output == NULL) {
|
||||||
|
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
auto & layer = layers[i];
|
||||||
|
|
||||||
|
// norm
|
||||||
|
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||||
|
|
||||||
|
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
|
||||||
|
|
||||||
|
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
|
||||||
|
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
|
||||||
|
|
||||||
|
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
|
||||||
|
|
||||||
|
// no "weight" suffix for these
|
||||||
|
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
|
||||||
|
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
|
||||||
|
|
||||||
|
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
|
||||||
|
|
||||||
// out_proj
|
// out_proj
|
||||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
|
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
|
||||||
}
|
}
|
||||||
|
@ -4727,10 +4811,14 @@ void llama_model::print_info() const {
|
||||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||||
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||||
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
||||||
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
|
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
|
||||||
|
LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
|
||||||
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
|
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
|
||||||
|
|
||||||
if (!classifier_labels.empty()) {
|
if (!classifier_labels.empty()) {
|
||||||
|
@ -9765,9 +9853,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llm_build_mamba : public llm_graph_context {
|
struct llm_build_mamba : public llm_graph_context {
|
||||||
const llama_model & model;
|
llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||||
|
|
||||||
llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
|
|
||||||
ggml_tensor * cur;
|
ggml_tensor * cur;
|
||||||
ggml_tensor * inpL;
|
ggml_tensor * inpL;
|
||||||
|
|
||||||
|
@ -9785,7 +9871,11 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
LLM_NORM_RMS, il);
|
LLM_NORM_RMS, il);
|
||||||
cb(cur, "attn_norm", il);
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
|
if (model.arch == LLM_ARCH_MAMBA2) {
|
||||||
|
cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
|
||||||
|
} else {
|
||||||
|
cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
|
||||||
|
}
|
||||||
|
|
||||||
if (il == n_layer - 1 && inp_out_ids) {
|
if (il == n_layer - 1 && inp_out_ids) {
|
||||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
@ -9819,11 +9909,11 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
ggml_build_forward_expand(gf, cur);
|
ggml_build_forward_expand(gf, cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: split
|
|
||||||
ggml_tensor * build_mamba_layer(
|
ggml_tensor * build_mamba_layer(
|
||||||
llm_graph_input_rs * inp,
|
llm_graph_input_rs * inp,
|
||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
ggml_tensor * cur,
|
ggml_tensor * cur,
|
||||||
|
const llama_model & model,
|
||||||
const llama_ubatch & ubatch,
|
const llama_ubatch & ubatch,
|
||||||
int il) const {
|
int il) const {
|
||||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
@ -9834,6 +9924,8 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
const int64_t d_inner = hparams.ssm_d_inner;
|
const int64_t d_inner = hparams.ssm_d_inner;
|
||||||
const int64_t d_state = hparams.ssm_d_state;
|
const int64_t d_state = hparams.ssm_d_state;
|
||||||
const int64_t dt_rank = hparams.ssm_dt_rank;
|
const int64_t dt_rank = hparams.ssm_dt_rank;
|
||||||
|
const int64_t n_head = d_inner;
|
||||||
|
const int64_t head_dim = 1;
|
||||||
const int64_t n_seqs = ubatch.n_seqs;
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
|
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
|
||||||
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
|
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
|
||||||
|
@ -9849,15 +9941,8 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||||
|
|
||||||
// (ab)using the KV cache to store the states
|
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||||
ggml_tensor * conv = build_rs(
|
|
||||||
inp, gf, conv_states_all,
|
|
||||||
hparams.n_embd_r(), n_seqs);
|
|
||||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
||||||
ggml_tensor * ssm = build_rs(
|
|
||||||
inp, gf, ssm_states_all,
|
|
||||||
hparams.n_embd_s(), n_seqs);
|
|
||||||
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
|
||||||
|
|
||||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||||
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
||||||
|
@ -9906,8 +9991,8 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
|
ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
|
||||||
// split
|
// split
|
||||||
ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
|
ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
|
||||||
ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
|
ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
|
||||||
ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
|
ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
|
||||||
|
|
||||||
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
|
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
|
||||||
if (ssm_dt_b_c_rms) {
|
if (ssm_dt_b_c_rms) {
|
||||||
|
@ -9920,23 +10005,36 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
dt = build_lora_mm(model.layers[il].ssm_dt, dt);
|
dt = build_lora_mm(model.layers[il].ssm_dt, dt);
|
||||||
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
|
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
|
||||||
|
|
||||||
// Custom operator to optimize the parallel associative scan
|
cur = x;
|
||||||
// as described in the Annex D of the Mamba paper.
|
x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||||
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
|
|
||||||
ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C);
|
ggml_tensor * A = model.layers[il].ssm_a;
|
||||||
|
|
||||||
|
// use the states and the indices provided by build_recurrent_state
|
||||||
|
// (this is necessary in order to properly use the states before they are overwritten,
|
||||||
|
// while avoiding to make unnecessary copies of the states)
|
||||||
|
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
|
||||||
|
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
|
||||||
|
|
||||||
|
// Custom operator to optimize the parallel associative scan
|
||||||
|
// as described in the Annex D of the Mamba paper.
|
||||||
|
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
|
||||||
|
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
|
||||||
|
|
||||||
// store last states
|
// store last states
|
||||||
ggml_build_forward_expand(gf,
|
ggml_build_forward_expand(gf,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
|
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
|
||||||
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
||||||
|
|
||||||
ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
|
ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
|
||||||
|
|
||||||
// TODO: skip computing output earlier for unused tokens
|
// TODO: skip computing output earlier for unused tokens
|
||||||
|
|
||||||
// {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
|
y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
|
||||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
|
||||||
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
|
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
|
||||||
|
|
||||||
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
|
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
|
||||||
|
@ -9945,7 +10043,136 @@ struct llm_build_mamba : public llm_graph_context {
|
||||||
|
|
||||||
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
|
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
|
||||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
|
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
|
||||||
//cb(cur, "mamba_out", il);
|
// cb(cur, "mamba_out", il);
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * build_mamba2_layer(
|
||||||
|
llm_graph_input_rs * inp,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
const llama_model & model,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il) const {
|
||||||
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||||
|
|
||||||
|
const auto kv_head = mctx_cur->get_head();
|
||||||
|
|
||||||
|
const int64_t d_conv = hparams.ssm_d_conv;
|
||||||
|
const int64_t d_inner = hparams.ssm_d_inner;
|
||||||
|
const int64_t d_state = hparams.ssm_d_state;
|
||||||
|
const int64_t n_head = hparams.ssm_dt_rank;
|
||||||
|
const int64_t head_dim = d_inner / n_head;
|
||||||
|
const int64_t n_group = hparams.ssm_n_group;
|
||||||
|
const int64_t n_seqs = ubatch.n_seqs;
|
||||||
|
|
||||||
|
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
|
||||||
|
GGML_ASSERT(n_seqs != 0);
|
||||||
|
GGML_ASSERT(ubatch.equal_seqs);
|
||||||
|
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||||
|
|
||||||
|
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||||
|
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||||
|
|
||||||
|
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||||
|
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
|
||||||
|
|
||||||
|
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||||
|
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
|
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
||||||
|
|
||||||
|
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
|
||||||
|
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
|
||||||
|
|
||||||
|
// split the above in three
|
||||||
|
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
|
||||||
|
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
|
||||||
|
ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
|
||||||
|
|
||||||
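A small sketch of the offset arithmetic behind this split, using assumed Mamba-2-style dimensions (d_inner = 5120, d_state = 128, n_group = 1, head_dim = 64 are illustrative, not values read from any checkpoint):

#include <cstdint>
#include <cstdio>

// The fused in_proj produces d_in_proj = 2*d_inner + 2*n_group*d_state + n_head
// values per token; z, xBC and dt are carved out of it by plain offsets.
int main() {
    const int64_t d_inner  = 5120;  // assumed: 2 * n_embd
    const int64_t d_state  = 128;
    const int64_t n_group  = 1;
    const int64_t head_dim = 64;
    const int64_t n_head   = d_inner / head_dim;

    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;

    const int64_t off_z   = 0;                              // z   : d_inner values
    const int64_t off_xBC = d_inner;                        // xBC : d_inner + 2*n_group*d_state values
    const int64_t off_dt  = 2*d_inner + 2*n_group*d_state;  // dt  : n_head values

    std::printf("d_in_proj = %lld (z@%lld, xBC@%lld, dt@%lld)\n",
                (long long) d_in_proj, (long long) off_z, (long long) off_xBC, (long long) off_dt);
    return 0;
}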
|
// conv
|
||||||
|
{
|
||||||
|
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
|
||||||
|
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
|
||||||
|
|
||||||
|
// copy last (d_conv - 1) columns back into the state cache
|
||||||
|
ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf,
|
||||||
|
ggml_cpy(ctx0, last_conv,
|
||||||
|
ggml_view_1d(ctx0, conv_states_all,
|
||||||
|
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
|
||||||
|
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
|
||||||
|
|
||||||
|
// 1D convolution
|
||||||
|
// The equivalent is to make a self-overlapping view of conv_x
|
||||||
|
// over d_conv columns at each stride in the 3rd dimension,
|
||||||
|
// then element-wise multiply that with the conv1d weight,
|
||||||
|
// then sum the elements of each row,
|
||||||
|
// (the last two steps are a dot product over rows (also doable with mul_mat))
|
||||||
|
// then permute away the ne[0] dimension,
|
||||||
|
// and then you're left with the resulting x tensor.
|
||||||
|
// For simultaneous sequences, all sequences need to have the same length.
|
||||||
|
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
|
||||||
|
|
||||||
|
// bias
|
||||||
|
xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
|
||||||
|
|
||||||
|
xBC = ggml_silu(ctx0, xBC);
|
||||||
|
}
|
||||||
|
|
||||||
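A naive plain-C++ reference for the causal depthwise convolution described in the comment block above, for a single sequence and with illustrative names (no ggml calls):

#include <cstddef>
#include <vector>

// Each channel's timeline is prepended with the d_conv-1 carried-over state columns,
// then a window of d_conv samples is multiplied element-wise with that channel's
// filter and summed - the "self-overlapping view, multiply, sum rows" recipe.
std::vector<float> ssm_conv_ref(const std::vector<float> & conv_state, // n_channels * (d_conv-1), channel-major
                                const std::vector<float> & x,          // n_channels * n_tokens,   channel-major
                                const std::vector<float> & w,          // n_channels * d_conv,     channel-major
                                size_t n_channels, size_t n_tokens, size_t d_conv) {
    std::vector<float> y(n_channels * n_tokens, 0.0f);

    for (size_t c = 0; c < n_channels; ++c) {
        // concatenated timeline for this channel: [state | new tokens]
        std::vector<float> line(d_conv - 1 + n_tokens);
        for (size_t k = 0; k < d_conv - 1; ++k) line[k]              = conv_state[c*(d_conv - 1) + k];
        for (size_t t = 0; t < n_tokens;   ++t) line[d_conv - 1 + t] = x[c*n_tokens + t];

        for (size_t t = 0; t < n_tokens; ++t) {
            float acc = 0.0f;
            for (size_t k = 0; k < d_conv; ++k) {
                acc += line[t + k] * w[c*d_conv + k]; // window ending at token t
            }
            y[c*n_tokens + t] = acc;
        }
    }

    return y;
}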
|
// ssm
|
||||||
|
{
|
||||||
|
// These correspond to V K Q in SSM/attention duality
|
||||||
|
ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
|
||||||
|
ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
|
||||||
|
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
|
||||||
|
|
||||||
|
// {n_head, n_seq_tokens, n_seqs}
|
||||||
|
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
|
||||||
|
|
||||||
|
ggml_tensor * A = model.layers[il].ssm_a;
|
||||||
|
|
||||||
|
// use the states and the indices provided by build_recurrent_state
|
||||||
|
// (this is necessary in order to properly use the states before they are overwritten,
|
||||||
|
// while avoiding to make unnecessary copies of the states)
|
||||||
|
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
|
||||||
|
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
|
||||||
|
|
||||||
|
// TODO: use semistructured matrices to implement state-space duality
|
||||||
|
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
|
||||||
|
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
|
||||||
|
|
||||||
|
// store last states
|
||||||
|
ggml_build_forward_expand(gf,
|
||||||
|
ggml_cpy(ctx0,
|
||||||
|
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
|
||||||
|
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
|
||||||
|
|
||||||
|
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
|
||||||
|
|
||||||
|
// TODO: skip computing output earlier for unused tokens
|
||||||
|
|
||||||
|
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||||
|
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
|
||||||
|
|
||||||
|
// grouped RMS norm
|
||||||
|
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
|
||||||
|
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
|
||||||
|
y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
|
||||||
|
|
||||||
|
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
|
||||||
|
cur = build_lora_mm(model.layers[il].ssm_out, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
|
||||||
|
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
|
||||||
|
// cb(cur, "mamba_out", il);
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
}
|
}
|
||||||
|
@@ -14768,6 +14995,7 @@ llm_graph_result_ptr llama_model::build_graph(
                 llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
             } break;
         case LLM_ARCH_MAMBA:
+        case LLM_ARCH_MAMBA2:
             {
                 llm = std::make_unique<llm_build_mamba>(*this, params, gf);
             } break;
@@ -15028,6 +15256,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_REFACT:
         case LLM_ARCH_BLOOM:
         case LLM_ARCH_MAMBA:
+        case LLM_ARCH_MAMBA2:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
@@ -172,6 +172,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_sub_norm = nullptr;
     struct ggml_tensor * attn_norm_cross = nullptr;
     struct ggml_tensor * attn_norm_enc = nullptr;
+    struct ggml_tensor * ssm_norm = nullptr;

     // attention
     struct ggml_tensor * wq = nullptr;
@@ -1430,8 +1430,7 @@ struct clip_graph {
             ggml_tensor * x = embeddings;
             embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
             x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-            embeddings = ggml_silu_inplace(ctx0, embeddings);
-            embeddings = ggml_mul(ctx0, embeddings,x);
+            embeddings = ggml_swiglu_split(ctx0, embeddings, x);
             embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
         }
         // arrangement of BOI/EOI token embeddings
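The two removed lines spell out what the fused call computes: SiLU of the first argument multiplied element-wise by the second. A small plain-C++ reference is below; the function and buffer names are mine, and it mirrors the replaced silu + mul pair rather than the actual ggml tensor op.

// Element-wise reference for the fused call above: swiglu_split(a, b) == silu(a) * b,
// matching the two lines it replaces. Plain C++ over flat buffers; illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float v) { return v / (1.0f + std::exp(-v)); }

std::vector<float> swiglu_split_ref(const std::vector<float> & a, const std::vector<float> & b) {
    std::vector<float> out(a.size());
    for (size_t i = 0; i < a.size(); ++i) out[i] = silu(a[i]) * b[i];
    return out;
}

int main() {
    std::vector<float> a = {0.5f, -1.0f, 2.0f};
    std::vector<float> b = {1.0f,  2.0f, 0.5f};
    for (float v : swiglu_split_ref(a, b)) std::printf("%f\n", v);
}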
@@ -1527,15 +1526,8 @@
         cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);

         // swiglu
-        {
-            int64_t split_point = cur->ne[0] / 2;
-            ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-            ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-            // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-            x1 = ggml_silu(ctx0, x1);
-            cur = ggml_mul(ctx0, x0, x1);
-        }
+        // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+        cur = ggml_swiglu_swapped(ctx0, cur);

         // mid-norm
         cur = ggml_rms_norm(ctx0, cur, 1e-6);
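Here the removed block documents the "swapped" semantics of the fused op: the row is split in half along ne[0], the second half goes through SiLU, and it gates the first half, which is the opposite of the usual SwiGLU order. A short plain-C++ reference over a single row, assuming an even row length, for illustration only:

// Reference for the fused call above: split a row in half, keep the first half as-is,
// SiLU the second half, multiply element-wise, exactly what the removed block did.
// Plain C++ over a single row; illustrative only.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float v) { return v / (1.0f + std::exp(-v)); }

std::vector<float> swiglu_swapped_ref(const std::vector<float> & row) {
    const size_t split_point = row.size() / 2;
    std::vector<float> out(split_point);
    for (size_t i = 0; i < split_point; ++i) {
        out[i] = row[i] * silu(row[split_point + i]);   // x0 * silu(x1)
    }
    return out;
}

int main() {
    std::vector<float> row = {1.0f, 2.0f, -0.5f, 0.25f};
    for (float v : swiglu_swapped_ref(row)) std::printf("%f\n", v);
}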
@@ -1794,35 +1786,42 @@ private:
             cur = tmp;
         }

+        // we only support parallel ffn for now
         switch (type_op) {
             case FFN_SILU:
-                {
+                if (gate) {
+                    cur = ggml_swiglu_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_swiglu", il);
+                } else {
                     cur = ggml_silu(ctx0, cur);
                     cb(cur, "ffn_silu", il);
                 } break;
             case FFN_GELU:
-                {
+                if (gate) {
+                    cur = ggml_geglu_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu", il);
+                } else {
                     cur = ggml_gelu(ctx0, cur);
                     cb(cur, "ffn_gelu", il);
                 } break;
             case FFN_GELU_ERF:
-                {
+                if (gate) {
+                    cur = ggml_geglu_erf_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu_erf", il);
+                } else {
                     cur = ggml_gelu_erf(ctx0, cur);
-                    cb(cur, "ggml_gelu_erf", il);
+                    cb(cur, "ffn_gelu_erf", il);
                 } break;
             case FFN_GELU_QUICK:
-                {
+                if (gate) {
+                    cur = ggml_geglu_quick_split(ctx0, cur, tmp);
+                    cb(cur, "ffn_geglu_quick", il);
+                } else {
                     cur = ggml_gelu_quick(ctx0, cur);
-                    cb(cur, "ffn_relu", il);
+                    cb(cur, "ffn_gelu_quick", il);
                 } break;
         }

-        // we only support parallel ffn for now
-        if (gate) {
-            cur = ggml_mul(ctx0, cur, tmp);
-            cb(cur, "ffn_gate_par", il);
-        }
-
         if (down) {
             cur = ggml_mul_mat(ctx0, down, cur);
         }
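This hunk folds the old two-step "activation, then multiply by the parallel gate branch" into one fused gated call per case, and fixes the ggml_gelu_erf / ffn_relu callback labels along the way. The sketch below mirrors that dispatch in plain C++ with a stand-in gated helper; the enum, function names, and the tanh-based GELU approximation are my own illustrative choices, not the clip.cpp API.

// Sketch of the dispatch the hunk introduces: when a parallel gate tensor exists,
// use the fused gated form (act(cur) * gate); otherwise apply the plain activation.
// Plain C++ over std::vector; the gelu here is the tanh approximation, purely illustrative.
#include <cmath>
#include <cstdio>
#include <vector>

enum ffn_op_type { FFN_SILU, FFN_GELU };

static float silu(float v) { return v / (1.0f + std::exp(-v)); }
static float gelu(float v) { return 0.5f * v * (1.0f + std::tanh(0.7978845608f * (v + 0.044715f * v * v * v))); }

std::vector<float> apply_ffn_act(ffn_op_type op, std::vector<float> cur, const std::vector<float> * gate) {
    for (size_t i = 0; i < cur.size(); ++i) {
        float a = (op == FFN_SILU) ? silu(cur[i]) : gelu(cur[i]);
        cur[i] = gate ? a * (*gate)[i] : a;   // fused "split" form vs. plain activation
    }
    return cur;
}

int main() {
    std::vector<float> up   = {0.5f, -1.0f, 2.0f};
    std::vector<float> gate = {1.0f,  2.0f, 0.5f};
    for (float v : apply_ffn_act(FFN_SILU, up, &gate)) std::printf("%f\n", v);
}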