llama : initial Mamba-2 support (#9126)

* llama : initial Mamba-2 support

* ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states

* llama : support running Mamba-Codestral-7B-v0.1

* llama : fix Mamba-2 conv state saving

* ggml : make the ggml_mul fast broadcast path more consistently formatted

* llama : remove unused variable

* llama : add missing break

* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.

* llama : avoid redundant state copy for Mamba 1 and 2

* metal : attempt to adapt SSM_SCAN for Mamba-2

* metal : fix SSM_SCAN pipeline scope

* metal : use log and exp instead of log1pf and expf in SSM_SCAN

* metal : remove unused arguments for SSM_SCAN

The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset

* metal : fix wrong number of tokens per sequence in SSM_SCAN

* ggml : remove unused fast broadcast path in GGML_MUL

This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models,
in order to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks

* convert : fix flake8 lint

* metal : fix confusion between ; and ,

* metal : add missing args for nb references in ssm_scan_f32_group

* metal : single-user mamba2 inference works

* kv-cache : remove const_cast when setting inputs for s_copy

This also fixes multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams

* kv-cache : allow context shift for recurrent models

* graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE

* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches

* cuda : implement ssm scan for Mamba2

There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2

* mamba : fix mismatched new and delete size for llm_build_mamba

Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This would otherwise cause problems when running Mamba-(1|2) inference
when compiled with -DGGML_SANITIZE_ADDRESS=ON.
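
To illustrate the constraint described above, here is a minimal standalone sketch (made-up type names, not the actual llm_graph_context hierarchy): when a subclass adds its own fields but the object is deleted through a base-class pointer whose destructor is non-virtual, only the base destructor runs, and the size handed to operator delete is the base size, which AddressSanitizer reports as a new/delete size mismatch.

#include <memory>
#include <vector>

struct base {
    int n = 0;
    ~base() = default; // non-virtual destructor, like a plain context struct
};

struct derived : base {
    std::vector<float> extra = std::vector<float>(1024); // extra field added by the subclass
};

int main() {
    // Deleting `derived` through a `base *`: only ~base() runs (undefined behavior),
    // `extra` is never destroyed, and under AddressSanitizer (which is what
    // -DGGML_SANITIZE_ADDRESS=ON enables for the project) this is reported as a
    // mismatched new/delete size.
    std::unique_ptr<base> p(new derived());
    return 0;
}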

* cuda : graceful fallback for Mamba-1 models with weird embd size
Commit 5d46babdc2 (parent e17991c466), authored by compilade on 2025-07-02 13:10:24 -04:00, committed by GitHub.
24 changed files with 1075 additions and 311 deletions

@@ -4781,6 +4781,14 @@ class ARwkv7Model(Rwkv7Model):
class MambaModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MAMBA

    def __init__(self, dir_model: Path, *args, **kwargs):
        # Avoid using AutoConfig for hparams
        hparams = kwargs.pop("hparams", None)
        if hparams is None:
            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                hparams = json.load(f)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)

    def set_vocab(self):
        vocab_size = self.hparams["vocab_size"]
        # Round vocab size to next multiple of 8
@@ -4855,6 +4863,100 @@ class MambaModel(TextModel):
        return [(new_name, data_torch)]


@ModelBase.register("Mamba2ForCausalLM")
class Mamba2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.MAMBA2

    def __init__(self, dir_model: Path, *args, **kwargs):
        # Avoid using AutoConfig for hparams
        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
        hparams = kwargs.pop("hparams", None)
        if hparams is None:
            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                hparams = json.load(f)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)

    def set_vocab(self):
        vocab_size = self.hparams["vocab_size"]
        # Round vocab size to next multiple of 16
        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
        # pad using ceiling division
        # ref: https://stackoverflow.com/a/17511341/22827863
        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
        self.hparams["vocab_size"] = vocab_size

        if (self.dir_model / "tokenizer.model").is_file():
            self._set_vocab_sentencepiece()
        elif (self.dir_model / "tokenizer.model.v3").is_file():
            # mamba-codestral
            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
        elif (self.dir_model / "tokenizer.json").is_file():
            self._set_vocab_gpt2()
        else:
            # Use the GPT-NeoX tokenizer when no tokenizer files are present
            self._set_vocab_builtin("gpt-neox", vocab_size)

    def set_gguf_parameters(self):
        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
        n_group = self.find_hparam(["n_groups"], optional=True) or 1

        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        # TODO: does this really matter?
        assert d_inner == 2 * d_model
        assert d_inner % head_dim == 0

        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
        self.gguf_writer.add_embedding_length(d_model)
        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
        self.gguf_writer.add_ssm_inner_size(d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
        self.gguf_writer.add_ssm_group_count(n_group)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_file_type(self.ftype)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
            name = name.removeprefix("model.")

        if name.endswith(".dt_bias"):
            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"

        new_name = self.map_tensor_name(name)

        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
            data_torch = data_torch.squeeze()
        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
            gguf.MODEL_TENSOR.SSM_A,
            gguf.MODEL_TENSOR.SSM_D,
        ]):
            # unsqueeze A to use similar shape semantics as Mamba-1
            # (D is also unsqueezed, but for more straightforward broadcast internally)
            data_torch = data_torch.reshape((*data_torch.shape, 1))
        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
            n_group = self.hparams.get("n_groups", 1)
            data_torch = data_torch.reshape((n_group, d_inner // n_group))

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
            data_torch = -torch.exp(data_torch)

        yield (new_name, data_torch)


@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6615,12 +6717,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
    # maybe we should fallback to text model's arch in that case, since not many models have both
    text_config = hparams.get("text_config", {})
    vision_config = hparams.get("vision_config", {})
    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        # For non-hf Mamba and Mamba2 models
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"

    # if "architectures" is found in the sub-config, use that instead
    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
        arch = text_config["architectures"][0]
    elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
        arch = vision_config["architectures"][0]
    if arch is None:
        raise ValueError("Failed to detect model architecture")
    return arch

@@ -2031,7 +2031,8 @@ extern "C" {
            struct ggml_tensor  * dt,
            struct ggml_tensor  * A,
            struct ggml_tensor  * B,
            struct ggml_tensor  * C,
            struct ggml_tensor  * ids);

    // partition into non-overlapping windows with padding if needed
    // example:
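
For context on the new argument, a usage sketch (illustrative only, not code from this patch; ctx and the dimension variables are placeholders): ids holds one int32 index per sequence and selects which slot along the last dimension of the state tensor s is read as the previous state, while the updated states are written into the result after the y output, as the backend implementations below show.

// shapes follow the comments in the CPU implementation below
struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_rs);
struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {1, n_head} for Mamba-2
struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);

// result: y {head_dim, n_head, n_seq_tokens, n_seqs} concatenated with the updated states
struct ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);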

@@ -8337,120 +8337,210 @@ void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
        const ggml_compute_params * params,
              ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

    const int ith = params->ith;
    const int nth = params->nth;

    const int64_t nc = src0->ne[0]; // d_state
    const int64_t nr = src0->ne[1]; // dim
    const int64_t nh = src1->ne[1]; // n_head
    const int64_t ng = src4->ne[1];
    const int64_t nt = src1->ne[2]; // number of tokens per sequence
    const int64_t ns = src1->ne[3]; // number of sequences in the batch

    // can't use ggml_nbytes because src1 is not necessarily contiguous
    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
    // allows optimizing the modulo since n_group should be a power of 2
    GGML_ASSERT((ng & -ng) == ng);

    // heads per thread
    const int dh = (nh + nth - 1)/nth;

    // head range for this thread
    const int ih0 = dh*ith;
    const int ih1 = MIN(ih0 + dh, nh);

    const int32_t * ids = (const int32_t *) src6->data;

    for (int i3 = 0; i3 < ns; ++i3) {
        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
              float * s  = (      float *) ((      char *)  dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}

        for (int i2 = 0; i2 < nt; ++i2) {
            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
                  float * y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}

            if (src3->ne[0] == 1) {
                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop

                // n_head
                for (int h = ih0; h < ih1; ++h) {
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
                    const float dA = expf(dt_soft_plus * A[h]);

                    // dim
                    for (int i1 = 0; i1 < nr; ++i1) {
                        const int ii = i1 + h*nr;
                        const float x_dt = x[ii] * dt_soft_plus;
                        float sumf = 0.0f;
#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
                        const int ggml_f32_epr = svcntw();
                        const int ggml_f32_step = 1 * ggml_f32_epr;

                        const int np = (nc & ~(ggml_f32_step - 1));

                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;

                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);

                        for (int i = 0; i < np; i += ggml_f32_step) {
                            // TODO: maybe unroll more?
                            for (int j = 0; j < 1; j++) {
                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);

                                t0 = GGML_F32_VEC_MUL(t0, adA);
                                t1 = GGML_F32_VEC_MUL(t1, axdt);

                                t0 = GGML_F32_VEC_ADD(t0, t1);

                                sum = GGML_F32_VEC_FMA(sum, t0, t2);

                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
                            }
                        }

                        sumf = GGML_F32xt_REDUCE_ONE(sum);
    #else
                        const int np = (nc & ~(GGML_F32_STEP - 1));

                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);

                        GGML_F32_VEC ax[GGML_F32_ARR];
                        GGML_F32_VEC ay[GGML_F32_ARR];
                        GGML_F32_VEC az[GGML_F32_ARR];

                        for (int i = 0; i < np; i += GGML_F32_STEP) {
                            for (int j = 0; j < GGML_F32_ARR; j++) {
                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);

                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);

                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);

                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);

                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
                            }
                        }

                        // reduce sum0..sum3 to sum0
                        GGML_F32_VEC_REDUCE(sumf, sum);
    #endif
#else
                        const int np = 0;
#endif
                        // d_state
                        for (int i0 = np; i0 < nc; ++i0) {
                            const int i = i0 + ii*nc;
                            const int ig = i0 + (h & (ng - 1))*nc;
                            // state = prev_state * dA + dB * x
                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
                            // y = rowwise_dotprod(state, C)
                            sumf += state * C[ig];
                            s[i] = state;
                        }
                        y[ii] = sumf;
                    }
                }
            } else {
                // Mamba-1 has an element-wise decay factor for the states

                // n_head
                for (int h = ih0; h < ih1; ++h) {
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];

                    // dim
                    for (int i1 = 0; i1 < nr; ++i1) {
                        const int ii = i1 + h*nr;
                        const float x_dt = x[ii] * dt_soft_plus;
#if defined(__ARM_FEATURE_SVE)
                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;

                        // d_state
                        // TODO: what happens when (d_state % svcntw()) != 0?
                        for (int64_t k = 0; k < nc; k += svcntw()) {
                            svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[h*nc + k]);
                            svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
                            svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);

                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
                            t1 = exp_ps_sve(svptrue_b32(), t1);
                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);

                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);

                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
                        }
                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
#else
                        float sumf = 0.0f;
                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
                        //       and also because expf is used within the loop.
                        // d_state
                        for (int i0 = 0; i0 < nc; ++i0) {
                            const int i = i0 + ii*nc;
                            const int ig = i0 + (h & (ng - 1))*nc;
                            // state = prev_state * dA + dB * x
                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                            // y = rowwise_dotprod(state, C)
                            sumf += state * C[ig];
                            s[i] = state;
                        }
                        y[ii] = sumf;
#endif
                    }
                }
            }
            // use the output as the source when it's not the first token-wise iteration
            s0 = s;
        }
    }
}

void ggml_compute_forward_ssm_scan(
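
For reference, both branches above compute the same per-token update. Writing it out (a restatement of the code, not an addition to it), with $\Delta_h = \log(1 + e^{dt_h})$ for $dt_h \le 20$ (and $\Delta_h = dt_h$ otherwise), $g(h) = h \bmod n_{\mathrm{group}}$, $d$ indexing the head dimension and $n$ the state dimension:

$$ s_{h,d,n} \leftarrow s_{h,d,n}\, e^{\Delta_h A_h} + B_{g(h),n}\, \Delta_h x_{h,d}, \qquad y_{h,d} = \sum_{n} s_{h,d,n}\, C_{g(h),n} $$

In the Mamba-2 branch $A_h$ is a single scalar per head, so the decay $e^{\Delta_h A_h}$ is hoisted out of the state loop; in the Mamba-1 branch the decay is element-wise, $e^{\Delta_h A_{h,n}}$, and is recomputed for every state element.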

@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt_LOAD(...)             GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b)    svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...)            GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...)              GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b)    svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...)              GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)

@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
        for (int i = 0; i < np; i += ggml_f32_step) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
        }
        // leftovers
        // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
        }
        // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
        if (np2 < n) {

@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

            GGML_F32_VEC_STORE(y + i, ay1);

            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);

            GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);

            GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);

            GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);

            GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);

            GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);

            GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);

            GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
        }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

            GGML_F32_VEC_STORE(y + i, ay1);
        }

@@ -3321,9 +3321,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
        case GGML_OP_LOG:
            return true;
        case GGML_OP_SSM_SCAN: {
            if (op->src[3]->ne[0] == 1) {
                // Mamba2
                // (kernel only supports d_state == 128 && d_head % 16 == 0)
                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
            } else {
                // Mamba
                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
            }
        }
        case GGML_OP_SSM_CONV: {
            // assumes d_inner % threads == 0
            return op->src[0]->ne[1] % 128 == 0;
        }
        case GGML_OP_CONT:
            return op->src[0]->type != GGML_TYPE_BF16;
        case GGML_OP_DIAG_MASK_INF:

@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L) {

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int bidx = blockIdx.x;  // split along B (sequences)
    const int bidy = blockIdx.y;  // split along D (d_inner)
    const int tid  = threadIdx.x;
    const int wid  = tid / 32;
    const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
    float * smem_A  = smem;
    float * smem_s0 = smem_A + splitD * stride_sA;

    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
    float * y_block        = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
    float * s_block        = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_A  = src3_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_s  = stride_s0;
    const int stride_y  = d_inner;

    // can N not be 16? for example 32?
    if (N == 16) {
@@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
        }
    }

// assumes as many threads as d_state
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

    const int head_idx = (blockIdx.x * splitH) / d_head;
    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
    const int seq_idx  = blockIdx.y;

    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);

    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float * y_block        = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
    float * s_block        = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    float state[splitH];
    // for the parallel accumulation
    __shared__ float stateC[splitH * d_state];

#pragma unroll
    for (int j = 0; j < splitH; j++) {
        state[j] = s0_block[j * d_state + threadIdx.x];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
        // TODO: only calculate B and C once per head group
        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
        float dt_soft_plus = dt_block[i * stride_dt];
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        const float dA = expf(dt_soft_plus * A_block[0]);
        const float B = B_block[i * stride_B + threadIdx.x];
        const float C = C_block[i * stride_C + threadIdx.x];

        // across d_head
#pragma unroll
        for (int j = 0; j < splitH; j++) {
            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B * x_dt);

            stateC[j * d_state + threadIdx.x] = state[j] * C;
        }

        __syncthreads();

        // parallel accumulation for stateC
        // TODO: simplify
        {
            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");

            // reduce until w matches the warp size
            // TODO: does this work even when the physical warp size is 64?
#pragma unroll
            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
                // (assuming there are d_state threads)
#pragma unroll
                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
                    // TODO: check for bank conflicts
                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
                    stateC[k] += stateC[k + (w >> 1)];
                }
                __syncthreads();
            }

            static_assert(splitH >= d_state / WARP_SIZE);

#pragma unroll
            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
                y = warp_reduce_sum(y);

                // store the above accumulations
                if (threadIdx.x % WARP_SIZE == 0) {
                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
                    y_block[i * stride_y + k] = y;
                }
            }
        }
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < splitH; j++) {
        s_block[j * d_state + threadIdx.x] = state[j];
    }
}

static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
            GGML_ASSERT(head_dim % splitH == 0);
            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else {
            GGML_ABORT("doesn't support d_state!=128.");
        }
    } else {
        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
        if (d_state == 16) {
            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
        } else {
            GGML_ABORT("doesn't support d_state!=16.");
        }
    }
}
@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src3 = dst->src[3]; // A
    const struct ggml_tensor * src4 = dst->src[4]; // B
    const struct ggml_tensor * src5 = dst->src[5]; // C
    const struct ggml_tensor * src6 = dst->src[6]; // ids

    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // head_dim or 1
    const int64_t nh  = src1->ne[1]; // n_head
    const int64_t ng  = src4->ne[1]; // n_group
    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
    const int64_t n_s = src1->ne[3]; // number of sequences in the batch

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const float * src3_d = (const float *) src3->data;
    const float * src4_d = (const float *) src4->data;
    const float * src5_d = (const float *) src5->data;
    const int32_t * src6_d = (const int32_t *) src6->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
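
Both the CPU and CUDA paths above map a head to its group with h & (n_group - 1) (head_idx & (n_group - 1) in the kernel) instead of a modulo, which is only valid when n_group is a power of two; the CPU path asserts this with (ng & -ng) == ng. A tiny standalone check of that identity (illustrative, not part of the patch):

#include <cassert>
#include <initializer_list>

int main() {
    for (int ng : {1, 2, 4, 8, 16}) {
        assert((ng & -ng) == ng);                // the power-of-two check used by the scan
        for (int h = 0; h < 64; ++h) {
            assert((h % ng) == (h & (ng - 1)));  // cheap modulo for power-of-two divisors
        }
    }
    return 0;
}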

@@ -513,26 +513,25 @@ typedef struct {
typedef struct {
    int64_t  d_state;
    int64_t  d_inner;
    int64_t  n_head;
    int64_t  n_group;
    int64_t  n_seq_tokens;
    int64_t  n_seqs;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb31;
    uint64_t nb41;
    uint64_t nb42;
    uint64_t nb43;
    uint64_t nb51;
    uint64_t nb52;
    uint64_t nb53;
} ggml_metal_kargs_ssm_scan;

typedef struct {

@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_NORM,
    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                norm,                true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,        ssm_conv_f32,        true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,        ssm_scan_f32,        true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,  ssm_scan_f32_group,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,       rwkv_wkv6_f32,       true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,       rwkv_wkv7_f32,       true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,      mul_mv_f32_f32,      has_simdgroup_reduction);
@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
            struct ggml_tensor * src3 = node->src[3];
            struct ggml_tensor * src4 = node->src[4];
            struct ggml_tensor * src5 = node->src[5];
            struct ggml_tensor * src6 = node->src[6];

            GGML_ASSERT(src3);
            GGML_ASSERT(src4);
            GGML_ASSERT(src5);
            GGML_ASSERT(src6);

            size_t offs_src3 = 0;
            size_t offs_src4 = 0;
            size_t offs_src5 = 0;
            size_t offs_src6 = 0;

            id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
            id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
            id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
            id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;

            const int64_t  ne30 = src3->ne[0];
            const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);

            const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
            const uint64_t nb31 = src3->nb[1];

            const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
            const int64_t  ne41 = src4->ne[1];
            const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
            const int64_t  ne43 = src4->ne[3]; GGML_UNUSED(ne43);

            const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
            const uint64_t nb41 = src4->nb[1];
            const uint64_t nb42 = src4->nb[2];
            const uint64_t nb43 = src4->nb[3];

            const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
            const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
            const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
            const int64_t  ne53 = src5->ne[3]; GGML_UNUSED(ne53);

            const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
            const uint64_t nb51 = src5->nb[1];
            const uint64_t nb52 = src5->nb[2];
            const uint64_t nb53 = src5->nb[3];

            const int64_t  ne60 = src6->ne[0]; GGML_UNUSED(ne60);

            const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);

            const int64_t d_state      = ne00;
            const int64_t d_inner      = ne01;
            const int64_t n_head       = ne02;
            const int64_t n_group      = ne41;
            const int64_t n_seq_tokens = ne12;
            const int64_t n_seqs       = ne13;

            id<MTLComputePipelineState> pipeline = nil;

            if (ne30 == 1) {
                // Mamba-2
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
            } else {
                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
            }

            ggml_metal_kargs_ssm_scan args = {
                /*.d_state      =*/ d_state,
                /*.d_inner      =*/ d_inner,
                /*.n_head       =*/ n_head,
                /*.n_group      =*/ n_group,
                /*.n_seq_tokens =*/ n_seq_tokens,
                /*.n_seqs       =*/ n_seqs,
                /*.nb01         =*/ nb01,
                /*.nb02         =*/ nb02,
                /*.nb03         =*/ nb03,
                /*.nb11         =*/ nb11,
                /*.nb12         =*/ nb12,
                /*.nb13         =*/ nb13,
                /*.nb21         =*/ nb21,
                /*.nb22         =*/ nb22,
                /*.nb31         =*/ nb31,
                /*.nb41         =*/ nb41,
                /*.nb42         =*/ nb42,
                /*.nb43         =*/ nb43,
                /*.nb51         =*/ nb51,
                /*.nb52         =*/ nb52,
                /*.nb53         =*/ nb53,
            };

            [encoder setComputePipelineState:pipeline];
@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
            [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
            [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
            [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
            [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
            [encoder setBytes:&args length:sizeof(args) atIndex:8];

            if (ne30 == 1) {
                // Mamba-2
                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } else {
                GGML_ASSERT(d_inner == 1);
                [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            }
        } break;
    case GGML_OP_RWKV_WKV6:
        {

@@ -1596,7 +1596,7 @@ kernel void kernel_ssm_conv_f32(
    x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
        device const void * src0,
        device const void * src1,
@ -1604,46 +1604,119 @@ kernel void kernel_ssm_scan_f32(
device const void * src3, device const void * src3,
device const void * src4, device const void * src4,
device const void * src5, device const void * src5,
device const void * src6,
device float * dst, device float * dst,
constant ggml_metal_kargs_ssm_scan & args, constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x; const int64_t i1 = 0;
const int64_t i3 = tgpig.y; const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state; const int64_t nc = args.d_state;
// const int64_t nr = args.d_inner; const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens; const int64_t n_t = args.n_seq_tokens;
// const int64_t n_s = args.n_seqs;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) { for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
if (i2 > 0) { const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
s0 = s; const float x_dt = x[0] * dt_soft_plus;
}
// i1 == 0
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f; float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) { for (int64_t i0 = 0; i0 < nc; ++i0) {
int64_t i = i0; const int64_t i = i0 + i1*nc;
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0]; sumf += state * C[i0];
s[i] = state; s[i] = state;
} }
y[0] = sumf; y[0] = sumf;
// recurse
s0 = s;
}
}
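
For readers less familiar with the scan, here is a plain C++ restatement of the per-(head, channel) recurrence that the Mamba-1 kernel above computes. It is a standalone reference sketch, not part of the Metal backend; the softplus overflow guard matches the kernel, log1pf is used here only for the scalar illustration.

#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference of one scan step for a single (head, channel) pair:
//   dt' = softplus(dt)   (guarded against overflow, as in the kernel)
//   s_i = exp(dt' * A_i) * s_i + B_i * (x * dt')
//   y   = sum_i s_i * C_i
static float ssm_scan_step(std::vector<float> & s, const std::vector<float> & A,
                           const std::vector<float> & B, const std::vector<float> & C,
                           float x, float dt) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float x_dt  = x * dt_sp;
    float y = 0.0f;
    for (size_t i = 0; i < s.size(); ++i) {
        s[i] = s[i] * expf(dt_sp * A[i]) + B[i] * x_dt;
        y   += s[i] * C[i];
    }
    return y;
}

int main() {
    std::vector<float> s = {0.0f, 0.0f};        // d_state = 2, initial state
    const std::vector<float> A = {-1.0f, -0.5f};
    const std::vector<float> B = { 0.3f,  0.7f};
    const std::vector<float> C = { 1.0f,  1.0f};
    printf("%f\n", ssm_scan_step(s, A, B, C, /*x=*/0.5f, /*dt=*/0.1f));
}
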
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
// recurse
s0 = s;
    }
}
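
Both kernels index B and C with (ir & (ng - 1)), which equals ir % ng only when the group count is a power of two; typical Mamba-2 configurations use small power-of-two group counts (for example 8). A tiny standalone check of that assumption, with an example ng value:

#include <cassert>
#include <cstdint>

int main() {
    // Illustrative: the bitmask form used in the kernels matches modulo only for power-of-two group counts.
    const int64_t ng = 8; // example group count, assumed power of two
    for (int64_t ir = 0; ir < 128; ++ir) {
        assert((ir & (ng - 1)) == (ir % ng));
    }
    return 0;
}
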


@@ -4837,7 +4837,6 @@ struct ggml_tensor * ggml_ssm_conv(
    const int64_t n_s = sx->ne[2];

    // TODO: maybe support other strides than 1?
-   // FIXME: this is always true?
    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
    GGML_ASSERT(sx->ne[1] == d_inner);
    GGML_ASSERT(n_t >= 0);
@@ -4860,36 +4859,49 @@ struct ggml_tensor * ggml_ssm_scan(
        struct ggml_tensor * dt,
        struct ggml_tensor * A,
        struct ggml_tensor * B,
-       struct ggml_tensor * C) {
+       struct ggml_tensor * C,
+       struct ggml_tensor * ids) {
    GGML_ASSERT(ggml_is_contiguous(s));
-   GGML_ASSERT(ggml_is_contiguous(x));
    GGML_ASSERT(ggml_is_contiguous(dt));
    GGML_ASSERT(ggml_is_contiguous(A));
-   GGML_ASSERT(ggml_is_matrix(A));
-   GGML_ASSERT(ggml_is_3d(B));
-   GGML_ASSERT(ggml_is_3d(s));
+   GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
-   GGML_ASSERT(ggml_are_same_shape(x, dt));
+   GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+   GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+   GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
    GGML_ASSERT(ggml_are_same_shape(B, C));
+   GGML_ASSERT(ids->type == GGML_TYPE_I32);

    {
        const int64_t d_state      = s->ne[0];
-       const int64_t d_inner      = s->ne[1];
-       const int64_t n_seq_tokens = x->ne[1];
-       const int64_t n_seqs       = x->ne[2];
+       const int64_t head_dim     = x->ne[0];
+       const int64_t n_head       = x->ne[1];
+       const int64_t n_seq_tokens = x->ne[2];
+       const int64_t n_seqs       = x->ne[3];

-       GGML_ASSERT(s->ne[2] == n_seqs);
-       GGML_ASSERT(x->ne[0] == d_inner);
-       GGML_ASSERT(A->ne[0] == d_state);
-       GGML_ASSERT(A->ne[1] == d_inner);
+       GGML_ASSERT(dt->ne[0] == n_head);
+       GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+       GGML_ASSERT(dt->ne[2] == n_seqs);
+       GGML_ASSERT(ggml_is_3d(dt));
+       GGML_ASSERT(s->ne[1] == head_dim);
+       GGML_ASSERT(s->ne[2] == n_head);
        GGML_ASSERT(B->ne[0] == d_state);
-       GGML_ASSERT(B->ne[1] == n_seq_tokens);
-       GGML_ASSERT(B->ne[2] == n_seqs);
+       GGML_ASSERT(B->ne[2] == n_seq_tokens);
+       GGML_ASSERT(B->ne[3] == n_seqs);
+       GGML_ASSERT(ids->ne[0] == n_seqs);
+       GGML_ASSERT(ggml_is_vector(ids));
+       GGML_ASSERT(A->ne[1] == n_head);
+       GGML_ASSERT(ggml_is_matrix(A));
+
+       if (A->ne[0] != 1) {
+           // Mamba-1 has more granular decay factors
+           GGML_ASSERT(A->ne[0] == d_state);
+       }
    }

    // concatenated y + ssm_states
-   struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+   struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);

    result->op     = GGML_OP_SSM_SCAN;
    result->src[0] = s;
@@ -4898,6 +4910,7 @@ struct ggml_tensor * ggml_ssm_scan(
    result->src[3] = A;
    result->src[4] = B;
    result->src[5] = C;
+   result->src[6] = ids;

    return result;
}
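
ggml_ssm_scan returns a single flat F32 tensor: the per-token outputs y come first, followed by the final states for the n_seqs output sequences. A small sketch of the resulting sizes and of the byte offset of the state section, using made-up shapes:

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative shapes only, not taken from a specific model:
    const int64_t d_state = 128, head_dim = 64, n_head = 16, n_seq_tokens = 32, n_seqs = 2;

    // y covers ggml_nelements(x) = head_dim*n_head*n_seq_tokens*n_seqs elements
    const int64_t n_y = head_dim * n_head * n_seq_tokens * n_seqs;
    // the state section covers one state per output sequence: d_state*head_dim*n_head*n_seqs
    const int64_t n_s = d_state * head_dim * n_head * n_seqs;

    // byte offset of the states inside the concatenated result (all elements are F32)
    const int64_t s_off = n_y * (int64_t) sizeof(float);

    printf("result elements: %lld (y: %lld, states: %lld), state offset: %lld bytes\n",
           (long long) (n_y + n_s), (long long) n_y, (long long) n_s, (long long) s_off);
}
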


@ -170,6 +170,7 @@ class Keys:
INNER_SIZE = "{arch}.ssm.inner_size" INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size" STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank" TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class WKV: class WKV:
@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum):
RWKV7 = auto() RWKV7 = auto()
ARWKV7 = auto() ARWKV7 = auto()
MAMBA = auto() MAMBA = auto()
MAMBA2 = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
COHERE2 = auto() COHERE2 = auto()
@ -429,6 +431,7 @@ class MODEL_TENSOR(IntEnum):
SSM_DT = auto() SSM_DT = auto()
SSM_A = auto() SSM_A = auto()
SSM_D = auto() SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto() SSM_OUT = auto()
TIME_MIX_W0 = auto() TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto() TIME_MIX_W1 = auto()
@ -628,6 +631,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.MAMBA2: "mamba2",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2", MODEL_ARCH.COHERE2: "cohere2",
@ -730,6 +734,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@ -1714,6 +1719,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT, MODEL_TENSOR.SSM_OUT,
], ],
MODEL_ARCH.MAMBA2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [ MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
@ -2497,6 +2515,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization # tokenization


@ -861,6 +861,9 @@ class GGUFWriter:
def add_ssm_time_step_rank(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
def add_ssm_group_count(self, value: int) -> None:
self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
def add_ssm_dt_b_c_rms(self, value: bool) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)


@ -477,7 +477,7 @@ class TensorNameMap:
"encoder.layers.{bid}.norm2", # nomic-bert "encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
), ),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
@ -574,6 +574,10 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.D", "backbone.layers.{bid}.mixer.D",
), ),
MODEL_TENSOR.SSM_NORM: (
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: ( MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", "model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj",


@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" }, { LLM_ARCH_COHERE2, "cohere2" },
@ -170,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@ -1004,6 +1006,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
}, },
}, },
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{ {
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
{ {
@ -1761,6 +1779,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@ -1894,6 +1913,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
bool llm_arch_is_recurrent(const llm_arch & arch) { bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) { switch (arch) {
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7: case LLM_ARCH_RWKV7:


@ -49,6 +49,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2, LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA, LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_XVERSE, LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R, LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2, LLM_ARCH_COHERE2,
@ -174,6 +175,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK, LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS, LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE, LLM_KV_WKV_HEAD_SIZE,
@ -293,6 +295,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W1,


@@ -1466,7 +1466,7 @@ ggml_tensor * llm_graph_context::build_rs(
        uint32_t kv_head,
        uint32_t kv_size,
        int32_t rs_zero,
-       bool avoid_copies) const {
+       const llm_graph_get_rows_fn & get_state_rows) const {

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
@@ -1475,19 +1475,11 @@ ggml_tensor * llm_graph_context::build_rs(
    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

-   ggml_tensor * output_states;
-
-   if (!avoid_copies) {
-       // copy states
-       // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-       // {state_size, kv_size} -> {state_size, n_seqs}
-       output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-       ggml_build_forward_expand(gf, output_states);
-   } else {
-       // FIXME: make the gathering operation happen before the copy below
-       //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-       output_states = states;
-   }
+   // copy states
+   // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+   // {state_size, kv_size} -> {state_size, n_seqs}
+   ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+   ggml_build_forward_expand(gf, output_states);

    // copy extra states which won't be changed further (between n_seqs and n_kv)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1518,10 +1510,10 @@ ggml_tensor * llm_graph_context::build_rs(
        ggml_tensor * s,
        int32_t state_size,
        int32_t n_seqs,
-       bool avoid_copies) const {
-   const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+       const llm_graph_get_rows_fn & get_state_rows) const {
+   const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);

-   return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+   return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}

ggml_tensor * llm_graph_context::build_rs(
@@ -1530,10 +1522,10 @@ ggml_tensor * llm_graph_context::build_rs(
        ggml_tensor * s,
        int32_t state_size,
        int32_t n_seqs,
-       bool avoid_copies) const {
-   const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+       const llm_graph_get_rows_fn & get_state_rows) const {
+   const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();

-   return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+   return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}
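
The switch from an avoid_copies flag to a get_state_rows callback is about ordering: the caller's gather has to be recorded in the graph before the updated states are copied back over the same cells. A minimal, ggml-free illustration of that gather-then-overwrite ordering (all names here are invented for the example):

#include <cstdio>
#include <vector>

int main() {
    // Illustrative: states live in a shared pool; sequences pick their source cell via ids.
    std::vector<float> pool = {10.f, 20.f, 30.f, 40.f};
    const std::vector<int> ids = {2, 0}; // seq 0 reads cell 2, seq 1 reads cell 0

    // gather first (the role played by get_state_rows inside the graph)...
    std::vector<float> gathered(ids.size());
    for (size_t i = 0; i < ids.size(); ++i) {
        gathered[i] = pool[ids[i]];
    }

    // ...then it is safe to overwrite the destination cells with the updated states
    for (size_t i = 0; i < ids.size(); ++i) {
        pool[i] = gathered[i] + 1.0f; // stand-in for the scan's state update
    }

    printf("%g %g\n", pool[0], pool[1]); // prints: 31 11
}
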
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(


@ -424,6 +424,9 @@ struct llm_graph_params {
const llm_graph_cb & cb; const llm_graph_cb & cb;
}; };
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
struct llm_graph_context { struct llm_graph_context {
const llm_arch arch; const llm_arch arch;
@ -663,7 +666,7 @@ struct llm_graph_context {
uint32_t kv_head, uint32_t kv_head,
uint32_t kv_size, uint32_t kv_size,
int32_t rs_zero, int32_t rs_zero,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
llm_graph_input_rs * build_rs_inp() const; llm_graph_input_rs * build_rs_inp() const;
@ -673,7 +676,7 @@ struct llm_graph_context {
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rs( ggml_tensor * build_rs(
llm_graph_input_mem_hybrid * inp, llm_graph_input_mem_hybrid * inp,
@ -681,7 +684,7 @@ struct llm_graph_context {
ggml_tensor * s, ggml_tensor * s,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies = false) const; const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rwkv_token_shift_load( ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp, llm_graph_input_rs * inp,


@@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const {
    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-   return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+   // Corresponds to Mamba's conv_states size
+   return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}
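
As a worked example of the updated n_embd_r() formula, using assumed Mamba-Codestral-7B-like hyperparameters (d_conv = 4, d_inner = 8192, n_group = 8, d_state = 128; treat the exact numbers as illustrative):

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed hyperparameters, for illustration only
    const uint32_t ssm_d_conv = 4, ssm_d_inner = 8192, ssm_n_group = 8, ssm_d_state = 128;

    // same formula as llama_hparams::n_embd_r() above
    const uint32_t n_embd_r = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);

    printf("%u\n", n_embd_r); // 3 * (8192 + 2048) = 30720 conv-state floats per layer
}
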
uint32_t llama_hparams::n_embd_s() const { uint32_t llama_hparams::n_embd_s() const {


@ -114,6 +114,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0; uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0; uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0; uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state space models // for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;


@@ -208,23 +208,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
            } break;
        case GGML_OP_SSM_CONV:
            {
-               // FIXME
-               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+               const int64_t n_seq_tokens = 512;
+               const int64_t n_seqs       = 3;
+               ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
            } break;
        case GGML_OP_SSM_SCAN:
            {
-               // FIXME
-               const int64_t d_state      = w->ne[0];
-               const int64_t d_inner      = w->ne[1];
+               // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2
+               const int64_t d_state      = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
+               const int64_t n_head       = w->ne[1];
+               const int64_t head_dim     = hparams.ssm_d_inner / n_head;
+               const int64_t n_group      = hparams.ssm_n_group ? hparams.ssm_n_group : 1;
                const int64_t n_seq_tokens = 512;
-               const int64_t n_seqs       = 1;
-               ggml_tensor * s   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-               ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-               ggml_tensor * B   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               ggml_tensor * C   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
+               const int64_t n_seqs       = 3;
+               ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
+               ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
+               ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+               ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+               op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
            } break;
        case GGML_OP_RWKV_WKV6:
            {
@ -1081,6 +1085,38 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_MAMBA2:
{
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 24:
switch (hparams.n_embd) {
case 768: type = LLM_TYPE_SMALL; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 48:
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_MEDIUM; break;
case 1536: type = LLM_TYPE_LARGE; break;
case 2048: type = LLM_TYPE_XL; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
case 64:
switch (hparams.n_embd) {
case 2560: type = LLM_TYPE_3B; break;
case 4096: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
} break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_XVERSE: case LLM_ARCH_XVERSE:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -3120,6 +3156,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
}
} break;
case LLM_ARCH_MAMBA2:
{
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
// only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner);
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
{
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed, duplicated to allow offloading
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
                    // out_proj
                    layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                }
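
With the same assumed Mamba-Codestral-like shapes as above, the d_in_proj expression used for the MAMBA2 tensors works out as follows (illustration only):

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed values for illustration (n_head = d_inner / head_dim)
    const int64_t d_inner = 8192, d_state = 128, n_group = 8, n_head = 128;

    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
    printf("%lld\n", (long long) d_in_proj); // 16384 + 2048 + 128 = 18560
}
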
@@ -4630,10 +4714,14 @@ void llama_model::print_info() const {
        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+   }

+   if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+       LLAMA_LOG_INFO("%s: ssm_n_group      = %u\n",     __func__, hparams.ssm_n_group);
        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);

    if (!classifier_labels.empty()) {
@ -9665,9 +9753,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
}; };
struct llm_build_mamba : public llm_graph_context {
-   const llama_model & model;
-   llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+   llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
        ggml_tensor * cur;
        ggml_tensor * inpL;
@@ -9685,7 +9771,11 @@ struct llm_build_mamba : public llm_graph_context {
                    LLM_NORM_RMS, il);
            cb(cur, "attn_norm", il);

-           cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
+           if (model.arch == LLM_ARCH_MAMBA2) {
+               cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
+           } else {
+               cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
+           }

            if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -9719,11 +9809,11 @@ struct llm_build_mamba : public llm_graph_context {
        ggml_build_forward_expand(gf, cur);
    }

-   // TODO: split
    ggml_tensor * build_mamba_layer(
            llm_graph_input_rs * inp,
            ggml_cgraph * gf,
            ggml_tensor * cur,
+           const llama_model & model,
            const llama_ubatch & ubatch,
            int il) const {
        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
@@ -9734,6 +9824,8 @@ struct llm_build_mamba : public llm_graph_context {
        const int64_t d_inner = hparams.ssm_d_inner;
        const int64_t d_state = hparams.ssm_d_state;
        const int64_t dt_rank = hparams.ssm_dt_rank;
+       const int64_t n_head  = d_inner;
+       const int64_t head_dim = 1;
        const int64_t n_seqs  = ubatch.n_seqs;
        // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
@@ -9749,15 +9841,8 @@ struct llm_build_mamba : public llm_graph_context {
        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

-       // (ab)using the KV cache to store the states
-       ggml_tensor * conv = build_rs(
-               inp, gf, conv_states_all,
-               hparams.n_embd_r(), n_seqs);
+       ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-       ggml_tensor * ssm = build_rs(
-               inp, gf, ssm_states_all,
-               hparams.n_embd_s(), n_seqs);
-       ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);

        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@@ -9806,8 +9891,8 @@ struct llm_build_mamba : public llm_graph_context {
        ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
        // split
        ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
-       ggml_tensor * B  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
-       ggml_tensor * C  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+       ggml_tensor * B  = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+       ggml_tensor * C  = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));

        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
        if (ssm_dt_b_c_rms) {
@@ -9820,23 +9905,36 @@ struct llm_build_mamba : public llm_graph_context {
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);

-       // Custom operator to optimize the parallel associative scan
-       // as described in the Annex D of the Mamba paper.
-       // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-       ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C);
+       cur = x;
+       x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
+
+       ggml_tensor * A = model.layers[il].ssm_a;
+
+       // use the states and the indices provided by build_recurrent_state
+       // (this is necessary in order to properly use the states before they are overwritten,
+       // while avoiding to make unnecessary copies of the states)
+       auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+           ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+           // Custom operator to optimize the parallel associative scan
+           // as described in the Annex D of the Mamba paper.
+           // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+           return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+       };
+
+       ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        ggml_build_forward_expand(gf,
            ggml_cpy(ctx0,
-               ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+               ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
                ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));

-       ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+       ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

        // TODO: skip computing output earlier for unused tokens
-       // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
-       y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+       y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
        y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
@@ -9845,7 +9943,136 @@ struct llm_build_mamba : public llm_graph_context {
        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-       //cb(cur, "mamba_out", il);
+       // cb(cur, "mamba_out", il);
return cur;
}
ggml_tensor * build_mamba2_layer(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_model & model,
const llama_ubatch & ubatch,
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto kv_head = mctx_cur->get_head();
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = d_inner / n_head;
const int64_t n_group = hparams.ssm_n_group;
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
// split the above in three
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
// conv
{
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
// copy last (d_conv - 1) columns back into the state cache
ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, last_conv,
ggml_view_1d(ctx0, conv_states_all,
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
// 1D convolution
// The equivalent is to make a self-overlapping view of conv_x
// over d_conv columns at each stride in the 3rd dimension,
// then element-wise multiply that with the conv1d weight,
// then sum the elements of each row,
// (the last two steps are a dot product over rows (also doable with mul_mat))
// then permute away the ne[0] dimension,
// and then you're left with the resulting x tensor.
// For simultaneous sequences, all sequences need to have the same length.
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
// bias
xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
xBC = ggml_silu(ctx0, xBC);
}
// ssm
{
// These correspond to V K Q in SSM/attention duality
ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
// {n_head, n_seq_tokens, n_seqs}
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
ggml_tensor * A = model.layers[il].ssm_a;
// use the states and the indices provided by build_recurrent_state
// (this is necessary in order to properly use the states before they are overwritten,
// while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
// grouped RMS norm
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(model.layers[il].ssm_out, y);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
// cb(cur, "mamba_out", il);
        return cur;
    }
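
The z / xBC / dt views carved out of zxBCdt in build_mamba2_layer sit back to back along the first dimension; a small sketch of their element offsets, again with assumed Mamba-Codestral-like shapes (illustration only):

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed shapes for illustration
    const int64_t d_inner = 8192, d_state = 128, n_group = 8, n_head = 128;

    const int64_t off_z   = 0;                             // z:   {head_dim, n_head, ...}
    const int64_t off_xBC = d_inner;                       // xBC: {d_inner + 2*n_group*d_state, ...}
    const int64_t off_dt  = 2*d_inner + 2*n_group*d_state; // dt:  {n_head, ...}
    const int64_t width   = off_dt + n_head;               // == d_in_proj

    printf("z@%lld xBC@%lld dt@%lld width=%lld\n",
           (long long) off_z, (long long) off_xBC, (long long) off_dt, (long long) width);
}
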
@ -14668,6 +14895,7 @@ llm_graph_result_ptr llama_model::build_graph(
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf); llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
} break; } break;
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
{ {
llm = std::make_unique<llm_build_mamba>(*this, params, gf); llm = std::make_unique<llm_build_mamba>(*this, params, gf);
} break; } break;
@ -14928,6 +15156,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_REFACT: case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM: case LLM_ARCH_BLOOM:
case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5: case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER: case LLM_ARCH_T5ENCODER:


@ -172,6 +172,7 @@ struct llama_layer {
struct ggml_tensor * ffn_sub_norm = nullptr; struct ggml_tensor * ffn_sub_norm = nullptr;
struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_cross = nullptr;
struct ggml_tensor * attn_norm_enc = nullptr; struct ggml_tensor * attn_norm_enc = nullptr;
struct ggml_tensor * ssm_norm = nullptr;
// attention // attention
struct ggml_tensor * wq = nullptr; struct ggml_tensor * wq = nullptr;


@@ -2084,28 +2084,58 @@ struct test_ssm_scan : public test_case {
    const ggml_type type;

    const int64_t d_state;
-   const int64_t d_inner;
+   const int64_t head_dim;
+   const int64_t n_head;
+   const int64_t n_group;
    const int64_t n_seq_tokens;
    const int64_t n_seqs;

    std::string vars() override {
-       return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+       return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
    }

    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-           int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-       : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+           int64_t d_state      = 32,
+           int64_t head_dim     = 1, // non-zero for Mamba-2
+           int64_t n_head       = 32,
+           int64_t n_group      = 1,
+           int64_t n_seq_tokens = 32,
+           int64_t n_seqs       = 32)
+       : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
-       ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
-       ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-       ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-       ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1, 1 }.data());
-       ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-       ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-       ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+       ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
+       ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+       ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+       ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
+       ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+       ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+       ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+       ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
        return out;
    }
// similar to test_mul_mat_id
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
} else {
init_tensor_uniform(t);
}
}
}
}; };
// GGML_OP_RWKV_WKV6
@@ -4498,7 +4528,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));

-   test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+   test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
+   test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2

    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));