mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-18 23:49:46 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # .github/workflows/build-linux-cross.yml # .github/workflows/build.yml # CODEOWNERS # ggml/CMakeLists.txt # ggml/src/ggml-cuda/fattn.cu # ggml/src/ggml-webgpu/CMakeLists.txt # ggml/src/ggml-webgpu/ggml-webgpu.cpp # ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl # tests/test-backend-ops.cpp # tests/test-chat-template.cpp # tools/llama-bench/llama-bench.cpp # tools/rpc/README.md # tools/server/README.md
This commit is contained in:
commit
b6f6338bba
32 changed files with 1556 additions and 636 deletions
|
|
@ -2607,6 +2607,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.no_extra_bufts = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_NO_REPACK"));
|
||||
add_opt(common_arg(
|
||||
{"--no-host"},
|
||||
"bypass host buffer allowing extra buffers to be used",
|
||||
[](common_params & params) {
|
||||
params.no_host = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_NO_HOST"));
|
||||
add_opt(common_arg(
|
||||
{"-ctk", "--cache-type-k"}, "TYPE",
|
||||
string_format(
|
||||
|
|
@ -3875,7 +3882,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
|
||||
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
|
|
@ -3889,7 +3895,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
|
||||
params.model.hf_file = "e5-small-v2-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
|
|
@ -3903,7 +3908,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
|
||||
params.model.hf_file = "gte-small-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
|
|
|
|||
|
|
@ -1141,6 +1141,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
|||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
mparams.no_host = params.no_host;
|
||||
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
|
|
|
|||
|
|
@ -388,6 +388,7 @@ struct common_params {
|
|||
bool check_tensors = false; // validate tensor data
|
||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
||||
bool no_host = false; // bypass host buffer allowing extra buffers to be used
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
|
|
|
|||
|
|
@ -8841,6 +8841,75 @@ class LFM2Model(TextModel):
|
|||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2MoeForCausalLM")
|
||||
class LFM2MoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2MOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# set num_key_value_heads only for attention layers
|
||||
self.hparams["num_key_value_heads"] = [
|
||||
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
|
||||
for layer_type in self.hparams["layer_types"]
|
||||
]
|
||||
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
|
||||
|
||||
# cache for experts weights for merging
|
||||
_experts_cache: dict[int, dict[str, Tensor]] = {}
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# conv op requires 2d tensor
|
||||
if 'conv.conv' in name:
|
||||
data_torch = data_torch.squeeze(1)
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
# merge expert weights
|
||||
if 'experts' in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
expert_cache = self._experts_cache.setdefault(bid, {})
|
||||
expert_cache[name] = data_torch
|
||||
expert_weights = ["w1", "w2", "w3"]
|
||||
|
||||
# not enough expert weights to merge
|
||||
if len(expert_cache) < n_experts * len(expert_weights):
|
||||
return []
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
for w_name in expert_weights:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
|
||||
datas.append(expert_cache[ename])
|
||||
del expert_cache[ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
del self._experts_cache[bid]
|
||||
return tensors
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
assert not self._experts_cache
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2VlForConditionalGeneration")
|
||||
class LFM2VLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
// V /= S
|
||||
const float S_inv = 1.0f/S;
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
||||
|
||||
// dst indices
|
||||
|
|
|
|||
|
|
@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
// TODO: temporary until support is extended
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
|
||||
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (GGML_CUDA_CC_IS_AMD(cc) && ggml_cuda_should_use_wmma_fattn(cc)) { //kcpp: fix for rocwmma
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
|
|
|
|||
|
|
@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
|||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
||||
const char * suffix = "";
|
||||
|
||||
if (op->src[1]->ne[0] % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
|
|
@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
|||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
if (op->src[3]->ne[0] == 1) {
|
||||
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
snprintf(name, 256, "%s", base);
|
||||
const int nsg = (ne00 + 31)/32;
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
|
|
@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
|
|||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
||||
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -918,6 +924,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_UNUSED(op);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_%s",
|
||||
"flash_attn_ext_pad");
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
||||
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
||||
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
||||
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
||||
|
||||
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
||||
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
||||
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
||||
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
||||
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
||||
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_UNUSED(op);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_%s",
|
||||
"flash_attn_ext_blk");
|
||||
|
||||
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
|
||||
base,
|
||||
nqptg,
|
||||
ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
||||
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
||||
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
||||
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
||||
|
||||
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
||||
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
||||
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
||||
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
||||
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
||||
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const ggml_tensor * op,
|
||||
|
|
@ -925,6 +1021,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
|
|
@ -937,18 +1034,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|||
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
|
||||
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
|
||||
|
||||
// do bounds checks for the mask?
|
||||
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
|
||||
"flash_attn_ext",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
bc_mask,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg);
|
||||
|
|
@ -964,6 +1066,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
||||
|
||||
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
||||
|
|
@ -983,6 +1088,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
|
@ -1002,12 +1108,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg, nwg);
|
||||
|
|
@ -1023,6 +1130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
||||
|
|
|
|||
|
|
@ -135,6 +135,18 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
|
|||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
|
|
@ -142,6 +154,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
|||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
|
|
@ -151,6 +164,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
|
|
|
|||
|
|
@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
};
|
||||
}
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
|
|
|
|||
|
|
@ -69,11 +69,20 @@
|
|||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// function constants offsets
|
||||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
#define FC_MUL_MM 500
|
||||
#define FC_FLASH_ATTN_EXT_PAD 100
|
||||
#define FC_FLASH_ATTN_EXT_BLK 200
|
||||
#define FC_FLASH_ATTN_EXT 300
|
||||
#define FC_FLASH_ATTN_EXT_VEC 400
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
|
||||
#define FC_MUL_MV 600
|
||||
#define FC_MUL_MM 700
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
||||
#define OP_FLASH_ATTN_EXT_NCPSG 64
|
||||
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
|
|
@ -178,6 +187,7 @@ typedef struct {
|
|||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t nk0;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
|
|
@ -243,6 +253,35 @@ typedef struct {
|
|||
int32_t sect_3;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
} ggml_metal_kargs_flash_attn_ext_pad;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne30;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
} ggml_metal_kargs_flash_attn_ext_blk;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
|
|
@ -261,6 +300,7 @@ typedef struct {
|
|||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
|
|
@ -295,6 +335,7 @@ typedef struct {
|
|||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
|
|
@ -572,32 +613,45 @@ typedef struct {
|
|||
int64_t n_seq_tokens;
|
||||
int64_t n_seqs;
|
||||
uint64_t s_off;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t ns12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb20;
|
||||
uint64_t nb21;
|
||||
uint64_t ns21;
|
||||
uint64_t nb22;
|
||||
int64_t ne30;
|
||||
uint64_t nb31;
|
||||
uint64_t nb41;
|
||||
uint64_t nb42;
|
||||
uint64_t ns42;
|
||||
uint64_t nb43;
|
||||
uint64_t nb51;
|
||||
uint64_t nb52;
|
||||
uint64_t ns52;
|
||||
uint64_t nb53;
|
||||
uint64_t nb0;
|
||||
} ggml_metal_kargs_ssm_scan;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int32_t ne00t;
|
||||
int32_t ne00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int64_t ne10;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_get_rows;
|
||||
|
||||
typedef struct {
|
||||
|
|
|
|||
|
|
@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
||||
|
||||
|
|
@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
||||
ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
||||
}
|
||||
if (node->src[2]) {
|
||||
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
||||
ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
||||
}
|
||||
if (node->src[3]) {
|
||||
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
||||
ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
||||
}
|
||||
if (node) {
|
||||
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
||||
node->name);
|
||||
|
|
@ -577,6 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
|
|
@ -906,23 +919,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
ggml_metal_kargs_get_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nw0 = (args.ne00t + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -1117,7 +1138,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
||||
|
||||
|
|
@ -1172,25 +1193,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.ns12 =*/ nb12/nb10,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb20 =*/ nb20,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.ns21 =*/ nb21/nb20,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.ne30 =*/ ne30,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb41 =*/ nb41,
|
||||
/*.nb42 =*/ nb42,
|
||||
/*.ns42 =*/ nb42/nb40,
|
||||
/*.nb43 =*/ nb43,
|
||||
/*.nb51 =*/ nb51,
|
||||
/*.nb52 =*/ nb52,
|
||||
/*.ns52 =*/ nb52/nb50,
|
||||
/*.nb53 =*/ nb53,
|
||||
/*.nb0 =*/ nb0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
|
||||
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
|
|
@ -1206,13 +1238,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
|
||||
|
||||
if (ne30 == 1) {
|
||||
// Mamba-2
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
} else {
|
||||
GGML_ASSERT(d_inner == 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
|
||||
}
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -1273,26 +1299,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
||||
|
||||
// TODO: support
|
||||
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
|
||||
const int32_t nk00 = ne00;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
int64_t nk0 = ne00;
|
||||
if (ggml_is_quantized(op->src[0]->type)) {
|
||||
nk0 = ne00/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne00/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk00) {
|
||||
nrptg = (nth + nk00 - 1)/nk00;
|
||||
nth = nk00;
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
|
|
@ -1300,10 +1323,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
}
|
||||
|
||||
nth = std::min(nth, nk00);
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.ne00 =*/ nk00,
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
|
|
@ -1321,12 +1345,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
|||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -1875,20 +1901,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
|
|||
return (ne01 < 20) && (ne00 % 32 == 0);
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
} else {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (!has_mask) {
|
||||
return res;
|
||||
}
|
||||
|
||||
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
||||
|
||||
// this optimization is not useful for the vector kernels
|
||||
if (is_vec) {
|
||||
return res;
|
||||
}
|
||||
|
||||
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
||||
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
||||
|
||||
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
||||
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
|
||||
|
||||
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
const int64_t nwg = 32;
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
const int64_t ne01 = op->src[0]->ne[1];
|
||||
const int64_t ne02 = op->src[0]->ne[2];
|
||||
const int64_t ne03 = op->src[0]->ne[3];
|
||||
const int64_t ne20 = op->src[2]->ne[0];
|
||||
size_t res = 0;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const int64_t nwg = 32;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
|
@ -1910,8 +2023,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne11 % 32 == 0);
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
||||
|
|
@ -1921,8 +2033,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_ASSERT(ne12 == ne22);
|
||||
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
|
|
@ -1949,15 +2061,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
GGML_ASSERT(ne01 < 65536);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
|
||||
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
||||
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_pad = bid_dst;
|
||||
bid_pad.offs += ggml_nbytes(op);
|
||||
|
||||
ggml_metal_buffer_id bid_blk = bid_pad;
|
||||
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_blk;
|
||||
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
|
||||
|
||||
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
// half8x8 kernel
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
|
||||
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
bool need_sync = false;
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_blk args0 = {
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne30 =*/ ne30,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb32 =*/ nb32,
|
||||
/*.nb33 =*/ nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
|
||||
|
||||
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
|
||||
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
||||
|
||||
// 2*(2*ncpsg)
|
||||
|
|
@ -2007,6 +2215,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
|
|
@ -2023,24 +2232,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
||||
}
|
||||
if (op->src[4]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
|
|
@ -2048,14 +2251,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
#undef FATTN_SMEM
|
||||
} else {
|
||||
// half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int64_t nkpsg = 1*ncpsg;
|
||||
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
|
||||
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int nkpsg = 1*ncpsg;
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
bool need_sync = false;
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
// ne00 + 2*ncpsg*(nsg)
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the soft_max values and the mask
|
||||
|
|
@ -2120,6 +2371,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
|
|
@ -2136,25 +2388,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
|
||||
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
||||
}
|
||||
if (op->src[4]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
|
@ -2162,23 +2406,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
||||
|
||||
if (nwg == 1) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
|
||||
|
||||
// using 1 workgroup -> write the result directly into dst
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||
} else {
|
||||
// sanity checks
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
||||
|
||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
// write the results from each workgroup into a temp buffer
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
|
|||
// return true if we should use the FA vector kernel for this op
|
||||
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
|||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
}
|
||||
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -407,6 +407,7 @@ class MODEL_ARCH(IntEnum):
|
|||
SMOLLM3 = auto()
|
||||
GPT_OSS = auto()
|
||||
LFM2 = auto()
|
||||
LFM2MOE = auto()
|
||||
DREAM = auto()
|
||||
SMALLTHINKER = auto()
|
||||
LLADA = auto()
|
||||
|
|
@ -749,6 +750,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.SMOLLM3: "smollm3",
|
||||
MODEL_ARCH.GPT_OSS: "gpt-oss",
|
||||
MODEL_ARCH.LFM2: "lfm2",
|
||||
MODEL_ARCH.LFM2MOE: "lfm2moe",
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
MODEL_ARCH.SMALLTHINKER: "smallthinker",
|
||||
MODEL_ARCH.LLADA: "llada",
|
||||
|
|
@ -2698,6 +2700,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
],
|
||||
MODEL_ARCH.LFM2MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.SHORTCONV_CONV,
|
||||
MODEL_TENSOR.SHORTCONV_INPROJ,
|
||||
MODEL_TENSOR.SHORTCONV_OUTPROJ,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM, # operator_norm
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.SMALLTHINKER: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -358,6 +358,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.router", # openai-moe
|
||||
"model.layers.{bid}.mlp.gate.wg", # hunyuan
|
||||
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
|
||||
"model.layers.{bid}.feed_forward.gate", # lfm2moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
|
|
@ -367,6 +368,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
|
|
|
|||
|
|
@ -299,6 +299,7 @@ extern "C" {
|
|||
bool use_mlock; // force system to keep model in RAM
|
||||
bool check_tensors; // validate model tensor data
|
||||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
bool no_host; // bypass host buffer allowing extra buffers to be used
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_LFM2MOE, "lfm2moe" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_LLADA, "llada" },
|
||||
|
|
@ -2104,6 +2105,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LFM2MOE,
|
||||
{
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
|
||||
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
|
||||
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
{
|
||||
|
|
@ -2493,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
|||
case LLM_ARCH_PLAMO2:
|
||||
case LLM_ARCH_GRANITE_HYBRID:
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
return true;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ enum llm_arch {
|
|||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_LFM2MOE,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_LLADA,
|
||||
|
|
|
|||
|
|
@ -590,7 +590,7 @@ int32_t llm_chat_apply_template(
|
|||
ss << message->content << "<|end_of_text|>\n";
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
|
||||
ss << "<|start_of_role|>assistant<|end_of_role|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
|
||||
// GigaChat template
|
||||
|
|
|
|||
|
|
@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch, false);
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
@ -859,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
|||
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
if (dest_seq_id != -1) {
|
||||
// single sequence
|
||||
|
||||
seq_rm(dest_seq_id, -1, -1);
|
||||
|
||||
if (cell_count == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
||||
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
|
||||
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
|
||||
case LLM_TYPE_A13B: return "A13B";
|
||||
case LLM_TYPE_8B_A1B: return "8B.A1B";
|
||||
case LLM_TYPE_21B_A3B: return "21B.A3B";
|
||||
case LLM_TYPE_30B_A3B: return "30B.A3B";
|
||||
case LLM_TYPE_106B_A12B: return "106B.A12B";
|
||||
|
|
@ -315,7 +316,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
|
|||
}
|
||||
|
||||
// CPU: ACCEL -> GPU host -> CPU extra -> CPU
|
||||
static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts) {
|
||||
static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts, bool no_host) {
|
||||
buft_list_t buft_list;
|
||||
|
||||
// add ACCEL buffer types
|
||||
|
|
@ -336,11 +337,13 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
|
|||
// generally, this will be done using the first device in the list
|
||||
// a better approach would be to handle this on a weight-by-weight basis using the offload_op
|
||||
// function of the device to determine if it would benefit from being stored in a host buffer
|
||||
for (auto * dev : devices) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
|
||||
if (buft) {
|
||||
buft_list.emplace_back(dev, buft);
|
||||
break;
|
||||
if (!no_host) {
|
||||
for (auto * dev : devices) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
|
||||
if (buft) {
|
||||
buft_list.emplace_back(dev, buft);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1998,14 +2001,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
|
||||
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
|
||||
}
|
||||
hparams.n_layer_dense_lead = hparams.n_layer;
|
||||
switch (hparams.n_ff()) {
|
||||
case 4608: type = LLM_TYPE_350M; break;
|
||||
case 6912: type = LLM_TYPE_700M; break;
|
||||
case 8192: type = LLM_TYPE_1_2B; break;
|
||||
case 10752: type = LLM_TYPE_2_6B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
|
||||
|
||||
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
|
||||
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
|
||||
}
|
||||
|
||||
type = LLM_TYPE_8B_A1B;
|
||||
} break;
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
{
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
|
|
@ -2088,7 +2106,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
|
||||
|
||||
// build a list of buffer types for the CPU and GPU devices
|
||||
pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts);
|
||||
pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host);
|
||||
for (auto * dev : devices) {
|
||||
buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
|
||||
// add CPU buffer types as a fallback
|
||||
|
|
@ -5870,6 +5888,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
|
|
@ -5881,11 +5900,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
// ffn is same for transformer and conv layers
|
||||
|
||||
const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead);
|
||||
|
||||
// ffn/moe is same for transformer and conv layers
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
if (is_moe_layer) {
|
||||
GGML_ASSERT(n_expert && n_expert_used);
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
|
||||
} else { // dense
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
|
||||
// for operator_norm
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
|
@ -6367,7 +6398,7 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_SMALLTHINKER) {
|
||||
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
}
|
||||
|
|
@ -18662,6 +18693,8 @@ struct llm_build_lfm2 : public llm_graph_context {
|
|||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_moe_layer = il >= static_cast<int>(hparams.n_layer_dense_lead);
|
||||
|
||||
auto * prev_cur = cur;
|
||||
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "model.layers.{}.operator_norm", il);
|
||||
|
|
@ -18676,7 +18709,16 @@ struct llm_build_lfm2 : public llm_graph_context {
|
|||
}
|
||||
|
||||
cur = ggml_add(ctx0, prev_cur, cur);
|
||||
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
|
||||
|
||||
auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(ffn_norm_out, "model.layers.{}.ffn_norm", il);
|
||||
|
||||
ggml_tensor * ffn_out = is_moe_layer ?
|
||||
build_moe_feed_forward(ffn_norm_out, il) :
|
||||
build_dense_feed_forward(ffn_norm_out, il);
|
||||
cb(ffn_norm_out, "model.layers.{}.ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_out);
|
||||
}
|
||||
|
||||
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
|
||||
|
|
@ -18691,23 +18733,32 @@ struct llm_build_lfm2 : public llm_graph_context {
|
|||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
ggml_tensor * build_feed_forward(ggml_tensor * cur,
|
||||
int il) const {
|
||||
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "model.layers.{}.ffn_norm", il);
|
||||
ggml_tensor * build_moe_feed_forward(ggml_tensor * cur,
|
||||
int il) const {
|
||||
return build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
model.layers[il].ffn_exp_probs_b,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
|
||||
il);
|
||||
}
|
||||
|
||||
ggml_tensor * build_dense_feed_forward(ggml_tensor * cur,
|
||||
int il) const {
|
||||
GGML_ASSERT(!model.layers[il].ffn_up_b);
|
||||
GGML_ASSERT(!model.layers[il].ffn_gate_b);
|
||||
GGML_ASSERT(!model.layers[il].ffn_down_b);
|
||||
cur = build_ffn(cur,
|
||||
return build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "model.layers.{}.feed_forward.w2", il);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_attn_block(ggml_tensor * cur,
|
||||
|
|
@ -19877,6 +19928,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_lfm2>(*this, params);
|
||||
} break;
|
||||
|
|
@ -19927,6 +19979,7 @@ llama_model_params llama_model_default_params() {
|
|||
/*.use_mlock =*/ false,
|
||||
/*.check_tensors =*/ false,
|
||||
/*.use_extra_bufts =*/ true,
|
||||
/*.no_host =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
@ -20098,6 +20151,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_OPENAI_MOE:
|
||||
case LLM_ARCH_HUNYUAN_DENSE:
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_LFM2MOE:
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ enum llm_type {
|
|||
LLM_TYPE_17B_16E, // llama4 Scout
|
||||
LLM_TYPE_17B_128E, // llama4 Maverick
|
||||
LLM_TYPE_A13B,
|
||||
LLM_TYPE_8B_A1B, // lfm2moe
|
||||
LLM_TYPE_21B_A3B, // Ernie MoE small
|
||||
LLM_TYPE_30B_A3B,
|
||||
LLM_TYPE_106B_A12B, // GLM-4.5-Air
|
||||
|
|
|
|||
|
|
@ -249,10 +249,9 @@ struct mtmd_context {
|
|||
} else if (proj == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
|
||||
tok_ov_img_start = {lookup_token("\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
|
||||
tok_ov_img_start = {lookup_token("\n\n"), lookup_token("<fake_token_around_image>"), lookup_token("<global-img>")};
|
||||
tok_ov_img_end = {lookup_token("<fake_token_around_image>")};
|
||||
tok_row_end = {lookup_token("\n")};
|
||||
img_beg = "<fake_token_around_image>";
|
||||
sli_img_start_tmpl = "<fake_token_around_image><row_%d_col_%d>";
|
||||
|
||||
} else if (proj == PROJECTOR_TYPE_PIXTRAL) {
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1937,7 +1937,7 @@ private:
|
|||
void cleanup_pending_task(int id_target) {
|
||||
// no need lock because this is called exclusively by post()
|
||||
auto rm_func = [id_target](const server_task & task) {
|
||||
return task.id_target == id_target;
|
||||
return task.id == id_target;
|
||||
};
|
||||
queue_tasks.erase(
|
||||
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
||||
|
|
@ -3676,6 +3676,20 @@ struct server_context {
|
|||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
|
|
@ -3700,6 +3714,11 @@ struct server_context {
|
|||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
|
||||
|
|
@ -3730,6 +3749,39 @@ struct server_context {
|
|||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
|
||||
|
||||
if (do_checkpoint) {
|
||||
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.ctx_checkpoints.front();
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3853,40 +3905,6 @@ struct server_context {
|
|||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
const bool do_checkpoint =
|
||||
(llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full);
|
||||
|
||||
if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
|
||||
while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.ctx_checkpoints.front();
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
|
||||
}
|
||||
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
|
||||
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
|
||||
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
|
@ -4184,6 +4202,7 @@ int main(int argc, char ** argv) {
|
|||
auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
static const std::unordered_set<std::string> public_endpoints = {
|
||||
"/health",
|
||||
"/v1/health",
|
||||
"/models",
|
||||
"/v1/models",
|
||||
"/api/tags"
|
||||
|
|
@ -5232,6 +5251,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// register API routes
|
||||
svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
|
||||
svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check)
|
||||
svr->Get (params.api_prefix + "/metrics", handle_metrics);
|
||||
svr->Get (params.api_prefix + "/props", handle_props);
|
||||
svr->Post(params.api_prefix + "/props", handle_props_change);
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
<script lang="ts">
|
||||
import { Search, SquarePen, X } from '@lucide/svelte';
|
||||
import { Search, SquarePen, X, Download, Upload } from '@lucide/svelte';
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { exportAllConversations, importConversations } from '$lib/stores/chat.svelte';
|
||||
|
||||
interface Props {
|
||||
handleMobileSidebarItemClick: () => void;
|
||||
|
|
@ -77,5 +78,34 @@
|
|||
|
||||
<KeyboardShortcutInfo keys={['cmd', 'k']} />
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
class="w-full justify-start text-sm"
|
||||
onclick={() => {
|
||||
importConversations().catch((err) => {
|
||||
console.error('Import failed:', err);
|
||||
// Optional: show toast or dialog
|
||||
});
|
||||
}}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<Upload class="h-4 w-4" />
|
||||
Import conversations
|
||||
</div>
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
class="w-full justify-start text-sm"
|
||||
onclick={() => {
|
||||
exportAllConversations();
|
||||
}}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<Download class="h-4 w-4" />
|
||||
Export all conversations
|
||||
</div>
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
<script lang="ts">
|
||||
import { Trash2, Pencil, MoreHorizontal } from '@lucide/svelte';
|
||||
import { Trash2, Pencil, MoreHorizontal, Download } from '@lucide/svelte';
|
||||
import { ActionDropdown } from '$lib/components/app';
|
||||
import { downloadConversation } from '$lib/stores/chat.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -101,6 +102,15 @@
|
|||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e) => {
|
||||
e.stopPropagation();
|
||||
downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import { filterByLeafNodeId, findLeafNode, findDescendantMessages } from '$lib/u
|
|||
import { browser } from '$app/environment';
|
||||
import { goto } from '$app/navigation';
|
||||
import { extractPartialThinking } from '$lib/utils/thinking';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import type { ExportedConversations } from '$lib/types/database';
|
||||
|
||||
/**
|
||||
* ChatStore - Central state management for chat conversations and AI interactions
|
||||
|
|
@ -951,6 +953,166 @@ class ChatStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads a conversation as JSON file
|
||||
* @param convId - The conversation ID to download
|
||||
*/
|
||||
async downloadConversation(convId: string): Promise<void> {
|
||||
if (!this.activeConversation || this.activeConversation.id !== convId) {
|
||||
// Load the conversation if not currently active
|
||||
const conversation = await DatabaseStore.getConversation(convId);
|
||||
if (!conversation) return;
|
||||
|
||||
const messages = await DatabaseStore.getConversationMessages(convId);
|
||||
const conversationData = {
|
||||
conv: conversation,
|
||||
messages
|
||||
};
|
||||
|
||||
this.triggerDownload(conversationData);
|
||||
} else {
|
||||
// Use current active conversation data
|
||||
const conversationData: ExportedConversations = {
|
||||
conv: this.activeConversation!,
|
||||
messages: this.activeMessages
|
||||
};
|
||||
|
||||
this.triggerDownload(conversationData);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Triggers file download in browser
|
||||
* @param data - Data to download (expected: { conv: DatabaseConversation, messages: DatabaseMessage[] })
|
||||
* @param filename - Optional filename
|
||||
*/
|
||||
private triggerDownload(data: ExportedConversations, filename?: string): void {
|
||||
const conversation =
|
||||
'conv' in data ? data.conv : Array.isArray(data) ? data[0]?.conv : undefined;
|
||||
if (!conversation) {
|
||||
console.error('Invalid data: missing conversation');
|
||||
return;
|
||||
}
|
||||
const conversationName = conversation.name ? conversation.name.trim() : '';
|
||||
const convId = conversation.id || 'unknown';
|
||||
const truncatedSuffix = conversationName
|
||||
.toLowerCase()
|
||||
.replace(/[^a-z0-9]/gi, '_')
|
||||
.replace(/_+/g, '_')
|
||||
.substring(0, 20);
|
||||
const downloadFilename = filename || `conversation_${convId}_${truncatedSuffix}.json`;
|
||||
|
||||
const conversationJson = JSON.stringify(data, null, 2);
|
||||
const blob = new Blob([conversationJson], {
|
||||
type: 'application/json'
|
||||
});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = downloadFilename;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
/**
|
||||
* Exports all conversations with their messages as a JSON file
|
||||
*/
|
||||
async exportAllConversations(): Promise<void> {
|
||||
try {
|
||||
const allConversations = await DatabaseStore.getAllConversations();
|
||||
if (allConversations.length === 0) {
|
||||
throw new Error('No conversations to export');
|
||||
}
|
||||
|
||||
const allData: ExportedConversations = await Promise.all(
|
||||
allConversations.map(async (conv) => {
|
||||
const messages = await DatabaseStore.getConversationMessages(conv.id);
|
||||
return { conv, messages };
|
||||
})
|
||||
);
|
||||
|
||||
const blob = new Blob([JSON.stringify(allData, null, 2)], {
|
||||
type: 'application/json'
|
||||
});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `all_conversations_${new Date().toISOString().split('T')[0]}.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
|
||||
toast.success(`All conversations (${allConversations.length}) prepared for download`);
|
||||
} catch (err) {
|
||||
console.error('Failed to export conversations:', err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports conversations from a JSON file.
|
||||
* Supports both single conversation (object) and multiple conversations (array).
|
||||
* Uses DatabaseStore for safe, encapsulated data access
|
||||
*/
|
||||
async importConversations(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const input = document.createElement('input');
|
||||
input.type = 'file';
|
||||
input.accept = '.json';
|
||||
|
||||
input.onchange = async (e) => {
|
||||
const file = (e.target as HTMLInputElement)?.files?.[0];
|
||||
if (!file) {
|
||||
reject(new Error('No file selected'));
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const text = await file.text();
|
||||
const parsedData = JSON.parse(text);
|
||||
let importedData: ExportedConversations;
|
||||
|
||||
if (Array.isArray(parsedData)) {
|
||||
importedData = parsedData;
|
||||
} else if (
|
||||
parsedData &&
|
||||
typeof parsedData === 'object' &&
|
||||
'conv' in parsedData &&
|
||||
'messages' in parsedData
|
||||
) {
|
||||
// Single conversation object
|
||||
importedData = [parsedData];
|
||||
} else {
|
||||
throw new Error(
|
||||
'Invalid file format: expected array of conversations or single conversation object'
|
||||
);
|
||||
}
|
||||
|
||||
const result = await DatabaseStore.importConversations(importedData);
|
||||
|
||||
// Refresh UI
|
||||
await this.loadConversations();
|
||||
|
||||
toast.success(`Imported ${result.imported} conversation(s), skipped ${result.skipped}`);
|
||||
|
||||
resolve(undefined);
|
||||
} catch (err: unknown) {
|
||||
const message = err instanceof Error ? err.message : 'Unknown error';
|
||||
console.error('Failed to import conversations:', err);
|
||||
toast.error('Import failed', {
|
||||
description: message
|
||||
});
|
||||
reject(new Error(`Import failed: ${message}`));
|
||||
}
|
||||
};
|
||||
|
||||
input.click();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a conversation and all its messages
|
||||
* @param convId - The conversation ID to delete
|
||||
|
|
@ -1427,6 +1589,9 @@ export const isInitialized = () => chatStore.isInitialized;
|
|||
export const maxContextError = () => chatStore.maxContextError;
|
||||
|
||||
export const createConversation = chatStore.createConversation.bind(chatStore);
|
||||
export const downloadConversation = chatStore.downloadConversation.bind(chatStore);
|
||||
export const exportAllConversations = chatStore.exportAllConversations.bind(chatStore);
|
||||
export const importConversations = chatStore.importConversations.bind(chatStore);
|
||||
export const deleteConversation = chatStore.deleteConversation.bind(chatStore);
|
||||
export const sendMessage = chatStore.sendMessage.bind(chatStore);
|
||||
export const gracefulStop = chatStore.gracefulStop.bind(chatStore);
|
||||
|
|
|
|||
|
|
@ -346,4 +346,39 @@ export class DatabaseStore {
|
|||
): Promise<void> {
|
||||
await db.messages.update(id, updates);
|
||||
}
|
||||
|
||||
/**
|
||||
* Imports multiple conversations and their messages.
|
||||
* Skips conversations that already exist.
|
||||
*
|
||||
* @param data - Array of { conv, messages } objects
|
||||
*/
|
||||
static async importConversations(
|
||||
data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[]
|
||||
): Promise<{ imported: number; skipped: number }> {
|
||||
let importedCount = 0;
|
||||
let skippedCount = 0;
|
||||
|
||||
return await db.transaction('rw', [db.conversations, db.messages], async () => {
|
||||
for (const item of data) {
|
||||
const { conv, messages } = item;
|
||||
|
||||
const existing = await db.conversations.get(conv.id);
|
||||
if (existing) {
|
||||
console.warn(`Conversation "${conv.name}" already exists, skipping...`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
await db.conversations.add(conv);
|
||||
for (const msg of messages) {
|
||||
await db.messages.put(msg);
|
||||
}
|
||||
|
||||
importedCount++;
|
||||
}
|
||||
|
||||
return { imported: importedCount, skipped: skippedCount };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
15
tools/server/webui/src/lib/types/database.d.ts
vendored
15
tools/server/webui/src/lib/types/database.d.ts
vendored
|
|
@ -54,3 +54,18 @@ export interface DatabaseMessage {
|
|||
timings?: ChatMessageTimings;
|
||||
model?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a single conversation with its associated messages,
|
||||
* typically used for import/export operations.
|
||||
*/
|
||||
export type ExportedConversation = {
|
||||
conv: DatabaseConversation;
|
||||
messages: DatabaseMessage[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Type representing one or more exported conversations.
|
||||
* Can be a single conversation object or an array of them.
|
||||
*/
|
||||
export type ExportedConversations = ExportedConversation | ExportedConversation[];
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue