Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	.github/workflows/build-linux-cross.yml
#	.github/workflows/build.yml
#	CODEOWNERS
#	ggml/CMakeLists.txt
#	ggml/src/ggml-cuda/fattn.cu
#	ggml/src/ggml-webgpu/CMakeLists.txt
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
#	tests/test-backend-ops.cpp
#	tests/test-chat-template.cpp
#	tools/llama-bench/llama-bench.cpp
#	tools/rpc/README.md
#	tools/server/README.md
This commit is contained in:
Concedo 2025-10-09 01:33:27 +08:00
commit b6f6338bba
32 changed files with 1556 additions and 636 deletions

View file

@ -2607,6 +2607,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.no_extra_bufts = true;
}
).set_env("LLAMA_ARG_NO_REPACK"));
add_opt(common_arg(
{"--no-host"},
"bypass host buffer allowing extra buffers to be used",
[](common_params & params) {
params.no_host = true;
}
).set_env("LLAMA_ARG_NO_HOST"));
add_opt(common_arg(
{"-ctk", "--cache-type-k"}, "TYPE",
string_format(
@ -3875,7 +3882,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
@ -3889,7 +3895,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
params.model.hf_file = "e5-small-v2-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;
@ -3903,7 +3908,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
params.model.hf_file = "gte-small-q8_0.gguf";
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
params.embd_normalize = 2;
params.n_ctx = 512;
params.verbose_prompt = true;

View file

@ -1141,6 +1141,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.use_extra_bufts = !params.no_extra_bufts;
mparams.no_host = params.no_host;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;

View file

@ -388,6 +388,7 @@ struct common_params {
bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
bool no_host = false; // bypass host buffer allowing extra buffers to be used
bool single_turn = false; // single turn chat conversation

View file

@ -8841,6 +8841,75 @@ class LFM2Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.LFM2MOE
def set_gguf_parameters(self):
# set num_key_value_heads only for attention layers
self.hparams["num_key_value_heads"] = [
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
for layer_type in self.hparams["layer_types"]
]
super().set_gguf_parameters()
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
# cache for experts weights for merging
_experts_cache: dict[int, dict[str, Tensor]] = {}
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# conv op requires 2d tensor
if 'conv.conv' in name:
data_torch = data_torch.squeeze(1)
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
expert_cache[name] = data_torch
expert_weights = ["w1", "w2", "w3"]
# not enough expert weights to merge
if len(expert_cache) < n_experts * len(expert_weights):
return []
tensors: list[tuple[str, Tensor]] = []
for w_name in expert_weights:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
datas.append(expert_cache[ename])
del expert_cache[ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
del self._experts_cache[bid]
return tensors
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
assert not self._experts_cache
@ModelBase.register("Lfm2VlForConditionalGeneration")
class LFM2VLModel(MmprojModel):
def __init__(self, *args, **kwargs):

View file

@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
// V /= S
const float S_inv = 1.0f/S;
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
// dst indices

View file

@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
const int cc = ggml_cuda_info().devices[device].cc;
// TODO: temporary until support is extended
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
return BEST_FATTN_KERNEL_NONE;
}
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && ggml_cuda_should_use_wmma_fattn(cc)) { //kcpp: fix for rocwmma
return BEST_FATTN_KERNEL_WMMA_F16;

View file

@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
char base[256];
char name[256];
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
const char * suffix = "";
if (op->src[1]->ne[0] % 4 == 0) {
suffix = "_4";
}
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
char base[256];
char name[256];
if (op->src[3]->ne[0] == 1) {
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
} else {
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
const int nsg = (ne00 + 31)/32;
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
return res;
}
@ -918,6 +924,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_pad");
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
base,
has_mask,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_blk");
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
base,
nqptg,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const ggml_tensor * op,
@ -925,6 +1021,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@ -937,18 +1034,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
// do bounds checks for the mask?
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
"flash_attn_ext",
ggml_type_name(op->src[1]->type),
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
bc_mask,
ns10,
ns20,
nsg);
@ -964,6 +1066,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@ -983,6 +1088,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@ -1002,12 +1108,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg, nwg);
@ -1023,6 +1130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

View file

@ -135,6 +135,18 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
@ -142,6 +154,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@ -151,6 +164,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg);

View file

@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
{
return op->ne[3] == 1;
}
return true;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {

View file

@ -69,11 +69,20 @@
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT_BLK 200
#define FC_FLASH_ATTN_EXT 300
#define FC_FLASH_ATTN_EXT_VEC 400
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
// kernel argument structs
//
@ -178,6 +187,7 @@ typedef struct {
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
int64_t ne01;
int64_t ne02;
@ -243,6 +253,35 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
typedef struct {
int32_t ne01;
int32_t ne30;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_blk;
typedef struct {
int32_t ne01;
int32_t ne02;
@ -261,6 +300,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@ -295,6 +335,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@ -572,32 +613,45 @@ typedef struct {
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t s_off;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t ns12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t ns21;
uint64_t nb22;
int64_t ne30;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t ns42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t ns52;
uint64_t nb53;
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
int32_t ne00t;
int32_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb03;
int32_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_get_rows;
typedef struct {

View file

@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ggml_is_contiguous(node->src[1]), node->src[1]->name);
}
if (node->src[2]) {
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
ggml_is_contiguous(node->src[2]), node->src[2]->name);
}
if (node->src[3]) {
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
ggml_is_contiguous(node->src[3]), node->src[3]->name);
}
if (node) {
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
node->name);
@ -577,6 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
@ -906,23 +919,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
ggml_metal_kargs_get_rows args = {
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
/*.ne00 =*/ ne00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nw0 = (args.ne00t + nth - 1)/nth;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
return 1;
}
@ -1117,7 +1138,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
@ -1172,25 +1193,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ns12 =*/ nb12/nb10,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.ns21 =*/ nb21/nb20,
/*.nb22 =*/ nb22,
/*.ne30 =*/ ne30,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.ns42 =*/ nb42/nb40,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.ns52 =*/ nb52/nb50,
/*.nb53 =*/ nb53,
/*.nb0 =*/ nb0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
@ -1206,13 +1238,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
if (ne30 == 1) {
// Mamba-2
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
} else {
GGML_ASSERT(d_inner == 1);
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
}
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
return 1;
}
@ -1273,26 +1299,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
// TODO: support
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
const int32_t nk00 = ne00;
int nth = 32; // SIMD width
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
int64_t nk0 = ne00;
if (ggml_is_quantized(op->src[0]->type)) {
nk0 = ne00/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne00/ggml_blck_size(op->type);
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk00) {
nrptg = (nth + nk00 - 1)/nk00;
nth = nk00;
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
@ -1300,10 +1323,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
}
}
nth = std::min(nth, nk00);
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.ne00 =*/ nk00,
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
@ -1321,12 +1345,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
/*.nb3 =*/ nb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
return 1;
}
@ -1875,20 +1901,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
return (ne01 < 20) && (ne00 % 32 == 0);
}
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
if (has_kvpad) {
res += OP_FLASH_ATTN_EXT_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
}
return res;
}
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (!has_mask) {
return res;
}
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
return res;
}
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
size_t res = 0;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const int64_t nwg = 32;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
return res;
}
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@ -1910,8 +2023,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@ -1921,8 +2033,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne12 == ne22);
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
float scale;
float max_bias;
@ -1949,15 +2061,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne01 < 65536);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_pad = bid_dst;
bid_pad.offs += ggml_nbytes(op);
ggml_metal_buffer_id bid_blk = bid_pad;
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
ggml_metal_buffer_id bid_tmp = bid_blk;
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (has_mask) {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
ggml_metal_kargs_flash_attn_ext_blk args0 = {
/*.ne01 =*/ ne01,
/*.ne30 =*/ ne30,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
// 2*(2*ncpsg)
@ -2007,6 +2215,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@ -2023,24 +2232,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@ -2048,14 +2251,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
const int64_t nkpsg = 1*ncpsg;
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nkpsg = 1*ncpsg;
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
bool need_sync = false;
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
@ -2120,6 +2371,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@ -2136,25 +2388,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
}
if (op->src[4]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
const size_t smem = FATTN_SMEM(nsg);
@ -2162,23 +2406,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
if (nwg == 1) {
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
// using 1 workgroup -> write the result directly into dst
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
// write the results from each workgroup into a temp buffer
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);

View file

@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
// return true if we should use the FA vector kernel for this op
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);

View file

@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
}
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
default:
break;

File diff suppressed because it is too large Load diff

View file

@ -407,6 +407,7 @@ class MODEL_ARCH(IntEnum):
SMOLLM3 = auto()
GPT_OSS = auto()
LFM2 = auto()
LFM2MOE = auto()
DREAM = auto()
SMALLTHINKER = auto()
LLADA = auto()
@ -749,6 +750,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.SMOLLM3: "smollm3",
MODEL_ARCH.GPT_OSS: "gpt-oss",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.LFM2MOE: "lfm2moe",
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
@ -2698,6 +2700,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.OUTPUT,
],
MODEL_ARCH.LFM2MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.SHORTCONV_CONV,
MODEL_TENSOR.SHORTCONV_INPROJ,
MODEL_TENSOR.SHORTCONV_OUTPROJ,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_NORM, # operator_norm
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.SMALLTHINKER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View file

@ -358,6 +358,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.router", # openai-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
"model.layers.{bid}.feed_forward.gate", # lfm2moe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -367,6 +368,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
),
# Feed-forward up

View file

@ -299,6 +299,7 @@ extern "C" {
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
bool no_host; // bypass host buffer allowing extra buffers to be used
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations

View file

@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_LFM2MOE, "lfm2moe" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
@ -2104,6 +2105,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT, "output" },
}
},
{
LLM_ARCH_LFM2MOE,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
@ -2493,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
return true;
default:

View file

@ -97,6 +97,7 @@ enum llm_arch {
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
LLM_ARCH_LFM2MOE,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,

View file

@ -590,7 +590,7 @@ int32_t llm_chat_apply_template(
ss << message->content << "<|end_of_text|>\n";
}
if (add_ass) {
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
ss << "<|start_of_role|>assistant<|end_of_role|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
// GigaChat template

View file

@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}
if (ubatch.n_tokens == 0) {

View file

@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}
if (ubatch.n_tokens == 0) {
@ -859,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
if (dest_seq_id != -1) {
// single sequence
seq_rm(dest_seq_id, -1, -1);
if (cell_count == 0) {
return true;
}
llama_batch_allocr balloc(hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

View file

@ -119,6 +119,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
@ -315,7 +316,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
}
// CPU: ACCEL -> GPU host -> CPU extra -> CPU
static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts) {
static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts, bool no_host) {
buft_list_t buft_list;
// add ACCEL buffer types
@ -336,11 +337,13 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
// generally, this will be done using the first device in the list
// a better approach would be to handle this on a weight-by-weight basis using the offload_op
// function of the device to determine if it would benefit from being stored in a host buffer
for (auto * dev : devices) {
ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
if (buft) {
buft_list.emplace_back(dev, buft);
break;
if (!no_host) {
for (auto * dev : devices) {
ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
if (buft) {
buft_list.emplace_back(dev, buft);
break;
}
}
}
@ -1998,14 +2001,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
hparams.n_layer_dense_lead = hparams.n_layer;
switch (hparams.n_ff()) {
case 4608: type = LLM_TYPE_350M; break;
case 6912: type = LLM_TYPE_700M; break;
case 8192: type = LLM_TYPE_1_2B; break;
case 10752: type = LLM_TYPE_2_6B; break;
default: type = LLM_TYPE_UNKNOWN;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_LFM2MOE:
{
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
type = LLM_TYPE_8B_A1B;
} break;
case LLM_ARCH_SMALLTHINKER:
{
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
@ -2088,7 +2106,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
// build a list of buffer types for the CPU and GPU devices
pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts);
pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host);
for (auto * dev : devices) {
buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
// add CPU buffer types as a fallback
@ -5870,6 +5888,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
@ -5881,11 +5900,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// ffn is same for transformer and conv layers
const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead);
// ffn/moe is same for transformer and conv layers
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
if (is_moe_layer) {
GGML_ASSERT(n_expert && n_expert_used);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
} else { // dense
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
// for operator_norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
@ -6367,7 +6398,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
}
if (arch == LLM_ARCH_SMALLTHINKER) {
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
}
@ -18662,6 +18693,8 @@ struct llm_build_lfm2 : public llm_graph_context {
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
const bool is_moe_layer = il >= static_cast<int>(hparams.n_layer_dense_lead);
auto * prev_cur = cur;
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.operator_norm", il);
@ -18676,7 +18709,16 @@ struct llm_build_lfm2 : public llm_graph_context {
}
cur = ggml_add(ctx0, prev_cur, cur);
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(ffn_norm_out, "model.layers.{}.ffn_norm", il);
ggml_tensor * ffn_out = is_moe_layer ?
build_moe_feed_forward(ffn_norm_out, il) :
build_dense_feed_forward(ffn_norm_out, il);
cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_out);
}
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
@ -18691,23 +18733,32 @@ struct llm_build_lfm2 : public llm_graph_context {
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_feed_forward(ggml_tensor * cur,
int il) const {
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.ffn_norm", il);
ggml_tensor * build_moe_feed_forward(ggml_tensor * cur,
int il) const {
return build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
il);
}
ggml_tensor * build_dense_feed_forward(ggml_tensor * cur,
int il) const {
GGML_ASSERT(!model.layers[il].ffn_up_b);
GGML_ASSERT(!model.layers[il].ffn_gate_b);
GGML_ASSERT(!model.layers[il].ffn_down_b);
cur = build_ffn(cur,
return build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "model.layers.{}.feed_forward.w2", il);
return cur;
}
ggml_tensor * build_attn_block(ggml_tensor * cur,
@ -19877,6 +19928,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
} break;
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
llm = std::make_unique<llm_build_lfm2>(*this, params);
} break;
@ -19927,6 +19979,7 @@ llama_model_params llama_model_default_params() {
/*.use_mlock =*/ false,
/*.check_tensors =*/ false,
/*.use_extra_bufts =*/ true,
/*.no_host =*/ false,
};
return result;
@ -20098,6 +20151,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_HUNYUAN_DENSE:
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_SMALLTHINKER:
case LLM_ARCH_GLM4_MOE:
case LLM_ARCH_SEED_OSS:

View file

@ -107,6 +107,7 @@ enum llm_type {
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_A13B,
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air

View file

@ -249,10 +249,9 @@ struct mtmd_context {
} else if (proj == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
tok_ov_img_start = {lookup_token("\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
tok_ov_img_start = {lookup_token("\n\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
tok_ov_img_end = {lookup_token("<fake_token_around_image>")};
tok_row_end = {lookup_token("\n")};
img_beg = "<fake_token_around_image>";
sli_img_start_tmpl = "<fake_token_around_image><row_%d_col_%d>";
} else if (proj == PROJECTOR_TYPE_PIXTRAL) {

Binary file not shown.

View file

@ -1937,7 +1937,7 @@ private:
void cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
auto rm_func = [id_target](const server_task & task) {
return task.id_target == id_target;
return task.id == id_target;
};
queue_tasks.erase(
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
@ -3676,6 +3676,20 @@ struct server_context {
alora_disabled_id = enabled_loras[0];
}
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
do_checkpoint = do_checkpoint && (
llama_model_is_recurrent(model) ||
llama_model_is_hybrid(model) ||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
@ -3700,6 +3714,11 @@ struct server_context {
slot.n_prompt_tokens_processed++;
slot.n_past++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
break;
}
}
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
@ -3730,6 +3749,39 @@ struct server_context {
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
// no need for empty or small checkpoints
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
if (do_checkpoint) {
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.ctx_checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
}
}
@ -3853,40 +3905,6 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
// make a checkpoint of the parts of the memory that cannot be rolled back.
// checkpoints are created only if:
// - the model uses SWA and we are not using `swa_full`
// - the model architecture is marked as recurrent or hybrid
//
// TODO: try to make this conditional on the context or the memory module, instead of the model type
const bool do_checkpoint =
(llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);
if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.ctx_checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
@ -4184,6 +4202,7 @@ int main(int argc, char ** argv) {
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
"/models",
"/v1/models",
"/api/tags"
@ -5232,6 +5251,7 @@ int main(int argc, char ** argv) {
// register API routes
svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check)
svr->Get (params.api_prefix + "/metrics", handle_metrics);
svr->Get (params.api_prefix + "/props", handle_props);
svr->Post(params.api_prefix + "/props", handle_props_change);

View file

@ -1,8 +1,9 @@
<script lang="ts">
import { Search, SquarePen, X } from '@lucide/svelte';
import { Search, SquarePen, X, Download, Upload } from '@lucide/svelte';
import { KeyboardShortcutInfo } from '$lib/components/app';
import { Button } from '$lib/components/ui/button';
import { Input } from '$lib/components/ui/input';
import { exportAllConversations, importConversations } from '$lib/stores/chat.svelte';
interface Props {
handleMobileSidebarItemClick: () => void;
@ -77,5 +78,34 @@
<KeyboardShortcutInfo keys={['cmd', 'k']} />
</Button>
<Button
class="w-full justify-start text-sm"
onclick={() => {
importConversations().catch((err) => {
console.error('Import failed:', err);
// Optional: show toast or dialog
});
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<Upload class="h-4 w-4" />
Import conversations
</div>
</Button>
<Button
class="w-full justify-start text-sm"
onclick={() => {
exportAllConversations();
}}
variant="ghost"
>
<div class="flex items-center gap-2">
<Download class="h-4 w-4" />
Export all conversations
</div>
</Button>
{/if}
</div>

View file

@ -1,6 +1,7 @@
<script lang="ts">
import { Trash2, Pencil, MoreHorizontal } from '@lucide/svelte';
import { Trash2, Pencil, MoreHorizontal, Download } from '@lucide/svelte';
import { ActionDropdown } from '$lib/components/app';
import { downloadConversation } from '$lib/stores/chat.svelte';
import { onMount } from 'svelte';
interface Props {
@ -101,6 +102,15 @@
onclick: handleEdit,
shortcut: ['shift', 'cmd', 'e']
},
{
icon: Download,
label: 'Export',
onclick: (e) => {
e.stopPropagation();
downloadConversation(conversation.id);
},
shortcut: ['shift', 'cmd', 's']
},
{
icon: Trash2,
label: 'Delete',

View file

@ -6,6 +6,8 @@ import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/u
import { browser } from '$app/environment';
import { goto } from '$app/navigation';
import { extractPartialThinking } from '$lib/utils/thinking';
import { toast } from 'svelte-sonner';
import type { ExportedConversations } from '$lib/types/database';
/**
* ChatStore - Central state management for chat conversations and AI interactions
@ -951,6 +953,166 @@ class ChatStore {
}
}
/**
* Downloads a conversation as JSON file
* @param convId - The conversation ID to download
*/
async downloadConversation(convId: string): Promise<void> {
if (!this.activeConversation || this.activeConversation.id !== convId) {
// Load the conversation if not currently active
const conversation = await DatabaseStore.getConversation(convId);
if (!conversation) return;
const messages = await DatabaseStore.getConversationMessages(convId);
const conversationData = {
conv: conversation,
messages
};
this.triggerDownload(conversationData);
} else {
// Use current active conversation data
const conversationData: ExportedConversations = {
conv: this.activeConversation!,
messages: this.activeMessages
};
this.triggerDownload(conversationData);
}
}
/**
* Triggers file download in browser
* @param data - Data to download (expected: { conv: DatabaseConversation, messages: DatabaseMessage[] })
* @param filename - Optional filename
*/
private triggerDownload(data: ExportedConversations, filename?: string): void {
const conversation =
'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined;
if (!conversation) {
console.error('Invalid data: missing conversation');
return;
}
const conversationName = conversation.name ? conversation.name.trim() : '';
const convId = conversation.id || 'unknown';
const truncatedSuffix = conversationName
.toLowerCase()
.replace(/[^a-z0-9]/gi, '_')
.replace(/_+/g, '_')
.substring(0, 20);
const downloadFilename = filename || `conversation_${convId}_${truncatedSuffix}.json`;
const conversationJson = JSON.stringify(data, null, 2);
const blob = new Blob([conversationJson], {
type: 'application/json'
});
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = downloadFilename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
/**
* Exports all conversations with their messages as a JSON file
*/
async exportAllConversations(): Promise<void> {
try {
const allConversations = await DatabaseStore.getAllConversations();
if (allConversations.length === 0) {
throw new Error('No conversations to export');
}
const allData: ExportedConversations = await Promise.all(
allConversations.map(async (conv) => {
const messages = await DatabaseStore.getConversationMessages(conv.id);
return { conv, messages };
})
);
const blob = new Blob([JSON.stringify(allData, null, 2)], {
type: 'application/json'
});
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
toast.success(`All conversations (${allConversations.length}) prepared for download`);
} catch (err) {
console.error('Failed to export conversations:', err);
throw err;
}
}
/**
* Imports conversations from a JSON file.
* Supports both single conversation (object) and multiple conversations (array).
* Uses DatabaseStore for safe, encapsulated data access
*/
async importConversations(): Promise<void> {
return new Promise((resolve, reject) => {
const input = document.createElement('input');
input.type = 'file';
input.accept = '.json';
input.onchange = async (e) => {
const file = (e.target as HTMLInputElement)?.files?.[0];
if (!file) {
reject(new Error('No file selected'));
return;
}
try {
const text = await file.text();
const parsedData = JSON.parse(text);
let importedData: ExportedConversations;
if (Array.isArray(parsedData)) {
importedData = parsedData;
} else if (
parsedData &&
typeof parsedData === 'object' &&
'conv' in parsedData &&
'messages' in parsedData
) {
// Single conversation object
importedData = [parsedData];
} else {
throw new Error(
'Invalid file format: expected array of conversations or single conversation object'
);
}
const result = await DatabaseStore.importConversations(importedData);
// Refresh UI
await this.loadConversations();
toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`);
resolve(undefined);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : 'Unknown error';
console.error('Failed to import conversations:', err);
toast.error('Import failed', {
description: message
});
reject(new Error(`Import failed: ${message}`));
}
};
input.click();
});
}
/**
* Deletes a conversation and all its messages
* @param convId - The conversation ID to delete
@ -1427,6 +1589,9 @@ export const isInitialized = () => chatStore.isInitialized;
export const maxContextError = () => chatStore.maxContextError;
export const createConversation = chatStore.createConversation.bind(chatStore);
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
export const exportAllConversations = chatStore.exportAllConversations.bind(chatStore);
export const importConversations = chatStore.importConversations.bind(chatStore);
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
export const sendMessage = chatStore.sendMessage.bind(chatStore);
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);

View file

@ -346,4 +346,39 @@ export class DatabaseStore {
): Promise<void> {
await db.messages.update(id, updates);
}
/**
* Imports multiple conversations and their messages.
* Skips conversations that already exist.
*
* @param data - Array of { conv, messages } objects
*/
static async importConversations(
data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[]
): Promise<{ imported: number; skipped: number }> {
let importedCount = 0;
let skippedCount = 0;
return await db.transaction('rw', [db.conversations, db.messages], async () => {
for (const item of data) {
const { conv, messages } = item;
const existing = await db.conversations.get(conv.id);
if (existing) {
console.warn(`Conversation "${conv.name}" already exists, skipping...`);
skippedCount++;
continue;
}
await db.conversations.add(conv);
for (const msg of messages) {
await db.messages.put(msg);
}
importedCount++;
}
return { imported: importedCount, skipped: skippedCount };
});
}
}

View file

@ -54,3 +54,18 @@ export interface DatabaseMessage {
timings?: ChatMessageTimings;
model?: string;
}
/**
* Represents a single conversation with its associated messages,
* typically used for import/export operations.
*/
export type ExportedConversation = {
conv: DatabaseConversation;
messages: DatabaseMessage[];
};
/**
* Type representing one or more exported conversations.
* Can be a single conversation object or an array of them.
*/
export type ExportedConversations = ExportedConversation | ExportedConversation[];