diff --git a/common/arg.cpp b/common/arg.cpp
index 2a817d6ea..b78b74b8c 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2607,6 +2607,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.no_extra_bufts = true;
         }
     ).set_env("LLAMA_ARG_NO_REPACK"));
+    add_opt(common_arg(
+        {"--no-host"},
+        "bypass host buffer allowing extra buffers to be used",
+        [](common_params & params) {
+            params.no_host = true;
+        }
+    ).set_env("LLAMA_ARG_NO_HOST"));
     add_opt(common_arg(
         {"-ctk", "--cache-type-k"}, "TYPE",
         string_format(
@@ -3875,7 +3882,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
             params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
             params.embd_normalize = 2;
             params.n_ctx = 512;
             params.verbose_prompt = true;
@@ -3889,7 +3895,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
             params.model.hf_file = "e5-small-v2-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
             params.embd_normalize = 2;
             params.n_ctx = 512;
             params.verbose_prompt = true;
@@ -3903,7 +3908,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
             params.model.hf_file = "gte-small-q8_0.gguf";
-            params.pooling_type = LLAMA_POOLING_TYPE_NONE;
             params.embd_normalize = 2;
             params.n_ctx = 512;
             params.verbose_prompt = true;
diff --git a/common/common.cpp b/common/common.cpp
index e7d3b4df7..cfe707e3e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1141,6 +1141,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.use_mlock       = params.use_mlock;
     mparams.check_tensors   = params.check_tensors;
     mparams.use_extra_bufts = !params.no_extra_bufts;
+    mparams.no_host         = params.no_host;
 
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
diff --git a/common/common.h b/common/common.h
index 57739cd80..01d5ed0e3 100644
--- a/common/common.h
+++ b/common/common.h
@@ -388,6 +388,7 @@ struct common_params {
     bool check_tensors  = false; // validate tensor data
     bool no_op_offload  = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
+    bool no_host        = false; // bypass host buffer allowing extra buffers to be used
 
     bool single_turn = false; // single turn chat conversation
 
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index a96389ab7..15edb59f0 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -8841,6 +8841,75 @@ class LFM2Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Lfm2MoeForCausalLM")
+class LFM2MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LFM2MOE
+
+    def set_gguf_parameters(self):
+        # set num_key_value_heads only for attention layers
+        self.hparams["num_key_value_heads"] = [
+            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            for layer_type in self.hparams["layer_types"]
+        ]
+
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
+
+    # cache for experts weights for merging
+    _experts_cache: dict[int, dict[str, Tensor]] = {}
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # conv op requires 2d tensor
+        if 'conv.conv' in name:
+            data_torch = data_torch.squeeze(1)
+
+        if name.endswith(".expert_bias"):
+            name = name.replace(".expert_bias", ".expert_bias.bias")
+
+        # merge expert weights
+        if 'experts' in name:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            expert_cache = self._experts_cache.setdefault(bid, {})
+            expert_cache[name] = data_torch
+            expert_weights = ["w1", "w2", "w3"]
+
+            # not enough expert weights to merge
+            if len(expert_cache) < n_experts * len(expert_weights):
+                return []
+
+            tensors: list[tuple[str, Tensor]] = []
+            for w_name in expert_weights:
+                datas: list[Tensor] = []
+
+                for xid in range(n_experts):
+                    ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
+                    datas.append(expert_cache[ename])
+                    del expert_cache[ename]
+
+                data_torch = torch.stack(datas, dim=0)
+                merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
+                new_name = self.map_tensor_name(merged_name)
+                tensors.append((new_name, data_torch))
+
+            del self._experts_cache[bid]
+            return tensors
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        assert not self._experts_cache
+
+
 @ModelBase.register("Lfm2VlForConditionalGeneration")
 class LFM2VLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 6275c8305..8e1a2de14 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         }
 
         // V /= S
-        const float S_inv = 1.0f/S;
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 54f4fb244..98e0377c3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     const int cc = ggml_cuda_info().devices[device].cc;
 
+    // TODO: temporary until support is extended
+    // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
+    if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && ggml_cuda_should_use_wmma_fattn(cc)) { //kcpp: fix for rocwmma
         return BEST_FATTN_KERNEL_WMMA_F16;
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 819f31c8a..e23abdda9 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    const char * suffix = "";
+
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
 }
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
     char base[256];
     char name[256];
 
-    if (op->src[3]->ne[0] == 1) {
-        snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
-    } else {
-        snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
-    }
-    snprintf(name, 256, "%s", base);
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
 
     res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
 
     return res;
 }
@@ -918,6 +924,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_pad");
+
+    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
+            base,
+            has_mask,
+            ncpsg);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
+    //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+    //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
+    //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
+
+    //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+    //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+    //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
+    //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
+    //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_blk");
+
+    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
+            base,
+            nqptg,
+            ncpsg);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
+    //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
+    //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
+    //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
+
+    //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
+    //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
+    //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
+    //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
+    ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
@@ -925,6 +1021,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
        int32_t nsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
@@ -937,18 +1034,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
     const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
 
+    // do bounds checks for the mask?
+    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
+
     snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
             "flash_attn_ext",
             ggml_type_name(op->src[1]->type),
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
+            bc_mask,
             ns10,
             ns20,
             nsg);
@@ -964,6 +1066,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
+
+    ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1088,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1108,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
             ns10,
             ns20,
             nsg,
             nwg);
@@ -1023,6 +1130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index f6ebf90a0..1034e4bbf 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -135,6 +135,18 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d   (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg);
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg);
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
@@ -142,6 +154,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +164,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index ec15815bf..cf13901cb 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_GET_ROWS:
-            {
-                return op->ne[3] == 1;
-            }
+            return true;
         case GGML_OP_SET_ROWS:
             {
                 if (op->src[0]->type != GGML_TYPE_F32) {
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 88c98423e..c9dff8730 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -69,11 +69,20 @@
 #define N_SG_IQ4_XS 2
 
 // function constants offsets
-#define FC_FLASH_ATTN_EXT            100
-#define FC_FLASH_ATTN_EXT_VEC        200
-#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
-#define FC_MUL_MV                    400
-#define FC_MUL_MM                    500
+#define FC_FLASH_ATTN_EXT_PAD        100
+#define FC_FLASH_ATTN_EXT_BLK        200
+#define FC_FLASH_ATTN_EXT            300
+#define FC_FLASH_ATTN_EXT_VEC        400
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
+#define FC_MUL_MV                    600
+#define FC_MUL_MM                    700
+
+// op-specific constants
+#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NCPSG 64
+
+#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
 // kernel argument structs
 //
@@ -178,6 +187,7 @@ typedef struct {
 } ggml_metal_kargs_clamp;
 
 typedef struct {
+    int64_t  nk0;
     int64_t  ne00;
     int64_t  ne01;
     int64_t  ne02;
@@ -243,6 +253,35 @@ typedef struct {
     int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
+typedef struct {
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ne31;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_pad;
+
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne30;
+    int32_t  ne31;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_blk;
+
 typedef struct {
     int32_t  ne01;
     int32_t  ne02;
@@ -261,6 +300,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
@@ -295,6 +335,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
@@ -572,32 +613,45 @@ typedef struct {
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
     uint64_t s_off;
+    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    uint64_t nb10;
     uint64_t nb11;
     uint64_t nb12;
+    uint64_t ns12;
     uint64_t nb13;
+    uint64_t nb20;
     uint64_t nb21;
+    uint64_t ns21;
     uint64_t nb22;
+    int64_t  ne30;
     uint64_t nb31;
     uint64_t nb41;
     uint64_t nb42;
+    uint64_t ns42;
     uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t ns52;
     uint64_t nb53;
+    uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {
-    int64_t  ne00;
+    int32_t  ne00t;
+    int32_t  ne00;
     uint64_t nb01;
     uint64_t nb02;
-    int64_t  ne10;
+    uint64_t nb03;
+    int32_t  ne10;
     uint64_t nb10;
     uint64_t nb11;
+    uint64_t nb12;
     uint64_t nb1;
     uint64_t nb2;
+    uint64_t nb3;
 } ggml_metal_kargs_get_rows;
 
 typedef struct {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e85a223c0..1137e2107 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
     GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
     GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
     GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
 
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ggml_is_contiguous(node->src[1]), node->src[1]->name);
     }
+    if (node->src[2]) {
+        GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                ggml_is_contiguous(node->src[2]), node->src[2]->name);
+    }
+    if (node->src[3]) {
+        GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                ggml_is_contiguous(node->src[3]), node->src[3]->name);
+    }
     if (node) {
         GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, node->name);
@@ -577,6 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
     ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ ne00,
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
@@ -906,23 +919,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
     ggml_metal_kargs_get_rows args = {
-        /*.ne00 =*/ ne00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.ne10 =*/ ne10,
-        /*.nb10 =*/ nb10,
-        /*.nb11 =*/ nb11,
-        /*.nb1  =*/ nb1,
-        /*.nb2  =*/ nb2,
+        /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
+        /*.ne00  =*/ ne00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne10  =*/ ne10,
+        /*.nb10  =*/ nb10,
+        /*.nb11  =*/ nb11,
+        /*.nb12  =*/ nb12,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
     };
 
+    const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    const int nw0 = (args.ne00t + nth - 1)/nth;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
 
     return 1;
 }
@@ -1117,7 +1138,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),  3);
 
     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
 
@@ -1172,25 +1193,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.s_off        =*/ ggml_nelements(op->src[1]) * sizeof(float),
+        /*.nb00         =*/ nb00,
         /*.nb01         =*/ nb01,
         /*.nb02         =*/ nb02,
         /*.nb03         =*/ nb03,
+        /*.nb10         =*/ nb10,
         /*.nb11         =*/ nb11,
         /*.nb12         =*/ nb12,
+        /*.ns12         =*/ nb12/nb10,
         /*.nb13         =*/ nb13,
+        /*.nb20         =*/ nb20,
         /*.nb21         =*/ nb21,
+        /*.ns21         =*/ nb21/nb20,
         /*.nb22         =*/ nb22,
+        /*.ne30         =*/ ne30,
         /*.nb31         =*/ nb31,
         /*.nb41         =*/ nb41,
         /*.nb42         =*/ nb42,
+        /*.ns42         =*/ nb42/nb40,
         /*.nb43         =*/ nb43,
         /*.nb51         =*/ nb51,
         /*.nb52         =*/ nb52,
+        /*.ns52         =*/ nb52/nb50,
         /*.nb53         =*/ nb53,
+        /*.nb0          =*/ nb0,
     };
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
 
+    GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
     const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1238,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
 
-    if (ne30 == 1) {
-        // Mamba-2
-        ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
-    } else {
-        GGML_ASSERT(d_inner == 1);
-        ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
-    }
+    ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
 
     return 1;
 }
@@ -1273,26 +1299,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
 
     GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
 
-    // TODO: support
-    //const int32_t nk00 = ne00/ggml_blck_size(op->type);
-    const int32_t nk00 = ne00;
-
-    int nth = 32; // SIMD width
-
-    while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth *= 2;
+    int64_t nk0 = ne00;
+    if (ggml_is_quantized(op->src[0]->type)) {
+        nk0 = ne00/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne00/ggml_blck_size(op->type);
     }
 
-    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
     // when rows are small, we can batch them together in a single threadgroup
     int nrptg = 1;
 
     // TODO: relax this constraint in the future
     if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
-        if (nth > nk00) {
-            nrptg = (nth + nk00 - 1)/nk00;
-            nth   = nk00;
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
 
             if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
                 nrptg--;
@@ -1300,10 +1323,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    nth = std::min(nth, nk00);
+    nth = std::min(nth, nk0);
 
     ggml_metal_kargs_cpy args = {
-        /*.ne00 =*/ nk00,
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
         /*.ne03 =*/ ne03,
@@ -1321,12 +1345,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3,
     };
 
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
 
     return 1;
 }
@@ -1875,20 +1901,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
     return (ne01 < 20) && (ne00 % 32 == 0);
 }
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+
+        if (has_kvpad) {
+            res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    } else {
+        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+
+        if (has_kvpad) {
+            res += OP_FLASH_ATTN_EXT_NCPSG*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    }
+
+    return res;
+}
+
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (!has_mask) {
+        return res;
+    }
+
+    const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
+
+    // this optimization is not useful for the vector kernels
+    if (is_vec) {
+        return res;
+    }
+
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
+
+    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
+    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
+
+    res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
-    const int64_t nwg = 32;
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
 
-    const int64_t ne01 = op->src[0]->ne[1];
-    const int64_t ne02 = op->src[0]->ne[2];
-    const int64_t ne03 = op->src[0]->ne[3];
-    const int64_t ne20 = op->src[2]->ne[0];
+    size_t res = 0;
 
-    // temp buffer for writing the results from each workgroup
-    // - ne20: the size of the Value head
-    // -    + 2: the S and M values for each intermediate result
-    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const int64_t nwg = 32;
+
+        // temp buffer for writing the results from each workgroup
+        // - ne20: the size of the Value head
+        // -    + 2: the S and M values for each intermediate result
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    }
+
+    return res;
 }
 
 int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1910,8 +2023,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
     GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
 
-    GGML_ASSERT(ne00 % 4  == 0);
-    GGML_ASSERT(ne11 % 32 == 0);
+    GGML_ASSERT(ne00 % 4 == 0);
 
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1921,8 +2033,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ne12 == ne22);
 
     GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
-    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
-            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
 
     float scale;
     float max_bias;
@@ -1949,15 +2061,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     GGML_ASSERT(ne01 < 65536);
 
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
+    ggml_metal_buffer_id bid_src3 = has_mask  ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
+    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
+
+    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_pad = bid_dst;
+    bid_pad.offs += ggml_nbytes(op);
+
+    ggml_metal_buffer_id bid_blk = bid_pad;
+    bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_blk;
+    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
+
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 8  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        bool need_sync = false;
+
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
+        if (has_mask) {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_blk args0 = {
+                /*.ne01 =*/ ne01,
+                /*.ne30 =*/ ne30,
+                /*.ne31 =*/ ne31,
+                /*.ne32 =*/ ne32,
+                /*.ne33 =*/ ne33,
+                /*.nb31 =*/ nb31,
+                /*.nb32 =*/ nb32,
+                /*.nb33 =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_blk,  2);
+
+            const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
+            const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
+        }
+
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
 
         // 2*(2*ncpsg)
@@ -2007,6 +2215,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2023,24 +2232,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 6);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
+        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
+        ggml_metal_encoder_set_buffer  (enc, bid_blk,  7);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  8);
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
@@ -2048,14 +2251,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int64_t nkpsg = 1*ncpsg;
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+        const int nkpsg = 1*ncpsg;
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 1  == 0);
         GGML_ASSERT(ncpsg  % 32 == 0);
 
+        bool need_sync = false;
+
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ne11,
+                /*.ne_12_2 =*/ne12,
+                /*.ne_12_3 =*/ne13,
+                /*.nb11    =*/nb11,
+                /*.nb12    =*/nb12,
+                /*.nb13    =*/nb13,
+                /*.nb21    =*/nb21,
+                /*.nb22    =*/nb22,
+                /*.nb23    =*/nb23,
+                /*.ne31    =*/ne31,
+                /*.ne32    =*/ne32,
+                /*.ne33    =*/ne33,
+                /*.nb31    =*/nb31,
+                /*.nb32    =*/nb32,
+                /*.nb33    =*/nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
+            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            need_sync = true;
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
+        if (need_sync) {
+            ggml_metal_op_concurrency_reset(ctx);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2120,6 +2371,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2136,25 +2388,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
 
         GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
-        if (op->src[3]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
-        }
-        if (op->src[4]) {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
-        } else {
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
-        }
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
 
         const size_t smem = FATTN_SMEM(nsg);
 
@@ -2162,23 +2406,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
 
         if (nwg == 1) {
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
+
             // using 1 workgroup -> write the result directly into dst
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
+            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
+
             GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
             GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
 
-            ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
-
             // write the results from each workgroup into a temp buffer
-            ggml_metal_buffer_id bid_tmp = bid_dst;
-            bid_tmp.offs += ggml_nbytes(op);
-            ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 8df4c72e7..d4cb94462 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
 // return true if we should use the FA vector kernel for this op
 bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
+size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index e11555a78..7afc881fa 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
-                    res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
-                }
+                res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
+                res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
+                res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
         default:
             break;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 96df6f0ce..45d91def8 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
+kernel void kernel_ssm_conv_f32_f32_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device      float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = args.ne10;
+  //const int64_t ncs = args.ne00;
+  //const int64_t nr  = args.ne01;
+  //const int64_t n_t = args.ne1;
+  //const int64_t n_s = args.ne2;
+
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+    device       float  * x = (device       float  *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2044,219 +2075,88 @@ kernel void kernel_ssm_scan_f32(
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device      float * dst,
        threadgroup float * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgptg[[simdgroups_per_threadgroup]],
-        uint3  tgpg[[threadgroups_per_grid]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+    constexpr short NW = N_SIMDWIDTH;
 
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = 0;
-    const int64_t ir = tgpig.x; // current head
-    const int64_t i3 = tgpig.y; // current seq
+    shared[tpitg.x] = 0.0f;
 
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
+    const int32_t i0 = tpitg.x;
+    const int32_t i1 = tgpig.x;
+    const int32_t ir = tgpig.y; // current head
+    const int32_t i3 = tgpig.z; // current seq
 
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
+    const int32_t nc  = args.d_state;
+    const int32_t nr  = args.d_inner;
+    const int32_t nh  = args.n_head;
+    const int32_t ng  = args.n_group;
+    const int32_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = args.s_off;
+    const int32_t s_off = args.s_off;
 
    device const int32_t * ids = (device const int32_t *) src6;
 
    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device       float *  s_buff = (device       float *) ((device       char *) dst  + ir*args.nb02 +     i3*args.nb03 + s_off);
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
+
+    const int32_t i = i0 + i1*nc;
+    const int32_t g = ir / (nh / ng); // repeat_interleave
+
    float s0 = s0_buff[i];
-    float s  = s_buff[i];
+    float s  = 0.0f;
 
-    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
-    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
 
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
+    const float A0 = A[i0%args.ne30];
 
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
+    device const float * x  = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
+    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
+    device const float * B  = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
+    device const float * C  = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
 
-        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-        s = state;
+    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
 
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        if (sgptg > 1) {
-
-            // Once per simd group, place the group sum into the shared buffer
-            if (tiisg == 0) {
-                shared[sgitg] = sumf;
-            }
-
-            // Wait for all threads in the threadgroup to reach this point. This
-            // ensures that all elements of the shared buffer are populated with the
-            // sum of the individual simd groups.
-            threadgroup_barrier(mem_flags::mem_threadgroup);
-
-            // For simd group 0 at indices < num simd groups, extract the shared
-            // simd sum
-            sumf = 0.0f;
-            if (sgitg == 0) {
-                if (tiisg < sgptg) {
-                    sumf = shared[tiisg];
-                }
-                sumf = simd_sum(sumf);
-                if (tiisg == 0) {
-                    y[0] = sumf;
-                }
-            }
-        } else if (tiisg == 0) {
-            y[0] = sumf;
-        }
-
-        // recurse
-        s0 = s;
-    }
-
-    // Assign the final state to the output buffer
-    s_buff[i] = s;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-kernel void kernel_ssm_scan_group_f32(
-        constant ggml_metal_kargs_ssm_scan & args,
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device const void * src6,
-        device      float * dst,
-        threadgroup float * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgptg[[simdgroups_per_threadgroup]],
-        uint3  tgpg[[threadgroups_per_grid]]) {
-
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = tgpig.x;
-    const int64_t ir = tgpig.y; // current head
-    const int64_t i3 = tgpig.z; // current seq
-
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
-
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
-
-    const int64_t s_off = args.s_off;
-
-    device const int32_t * ids = (device const int32_t *) src6;
-
-    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float *  s_buff = (device       float *) ((device       char *) dst  + ir*args.nb02 +     i3*args.nb03 + s_off);
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
-    float s0 = s0_buff[i];
-    float s  = s_buff[i];
-
-    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-    device       float * y_block  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
-
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
-        const float dA = exp(dt_soft_plus * A[0]);
-
-        const float state = (s0 * dA) + (B[i0] * x_dt);
-        s = state;
-
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        // Once per simd group, place the group sum into the shared buffer
-        if (tiisg == 0) {
-            shared[sgitg] = sumf;
-        }
-
-        // Wait for all threads in the threadgroup to reach this point. This
-        // ensures that all elements of the shared buffer are populated with the
-        // sum of the individual simd groups.
+    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // For simd group 0 at indices < num simd groups, extract the shared
-        // simd sum
-        sumf = 0.0f;
-        if (sgitg == 0) {
-            if (tiisg < sgptg) {
-                sumf = shared[tiisg];
-            }
-            sumf = simd_sum(sumf);
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float dt0  = dt[0];
+            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
+            const float x_dt = x[0] * dtsp;
+            const float dA   = exp(dtsp * A0);
+
+            s = (s0 * dA) + (B[i0] * x_dt);
+
+            const float sumf = simd_sum(s * C[i0]);
+
            if (tiisg == 0) {
-                y[0] = sumf;
+                shared[t*NW + sgitg] = sumf;
            }
+
+            // recurse
+            s0 = s;
+
+            x  += args.ns12;
+            dt += args.ns21;
+            B  += args.ns42;
+            C  += args.ns52;
        }
 
-        // recurse
-        s0 = s;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+
+        if (tiisg == 0 && i2 + sgitg < n_t) {
+            y[sgitg*nh*nr] = sumf;
+        }
+
+        y += sgptg*nh*nr;
    }
 
-    // Assign the final state to the output buffer
    s_buff[i] = s;
 }
 
@@ -4449,10 +4349,142 @@ kernel void kernel_leaky_relu_f32_4(
     dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
 }
 
+constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
+
+constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
+
+// pad the last chunk of C elements of k and v into a an extra pad buffer
+kernel void kernel_flash_attn_ext_pad(
+        constant ggml_metal_kargs_flash_attn_ext_pad & args,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int32_t C = FC_flash_attn_ext_pad_ncpsg;
+
+    device char * k_pad    = dst;
+    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
+    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+    const int32_t icp = args.ne11 % C;
+    const int32_t ic0 = args.ne11 - icp;
+
+    const int32_t i1 = tgpig[0];
+    const int32_t i2 = tgpig[1];
+    const int32_t i3 = tgpig[2];
+
+    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
+        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
+        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
+
+        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
+        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
+
+        if (i1 >= icp) {
+            // here it is not important the exact value that will be used as we rely on masking out the scores in the attention
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = 0;
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = 0;
+            }
+        } else {
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = k_src[i];
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = v_src[i];
+            }
+        }
+    }
+
+    if (FC_flash_attn_ext_pad_has_mask) {
+        if (i2 < args.ne32 && i3 < args.ne33) {
+            for (int ib = i1; ib < args.ne31; ib += C) {
+                device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
+                device       half * mask_dst = (device       half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
+
+                for (int i = tiitg; i < C; i += ntg.x) {
+                    if (i >= icp) {
+                        mask_dst[i] = -MAXHALF;
+                    } else {
+                        mask_dst[i] = mask_src[i];
+                    }
+                }
+            }
+        }
+    }
+}
+
+constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
+constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
+
+// scan the blocks of the mask that are not masked
+// 0 - masked (i.e. full of -INF, skip)
+// 1 - not masked (i.e. at least one element of the mask is not -INF)
+kernel void kernel_flash_attn_ext_blk(
+        constant ggml_metal_kargs_flash_attn_ext_blk & args,
+        device const char * mask,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    // block size C x Q
+    const int32_t Q = FC_flash_attn_ext_blk_nqptg;
+    const int32_t C = FC_flash_attn_ext_blk_ncpsg;
+
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig[2]/args.ne32;
+    const int32_t i2 = tgpig[2]%args.ne32;
+    const int32_t i1 = tgpig[1];
+    const int32_t i0 = tgpig[0];
+
+    char res = i0*C + C > args.ne30 ? 1 : 0;
+
+    device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
+
+    // fast route
+    if (res == 0) {
+        if (simd_max(*mask_src) > -MAXHALF/2) {
+            res = 1;
+        }
+    }
+
+    // detailed check of the elements of the block
+    if ((C > NW || Q > 1) && res == 0) {
+        half m = -MAXHALF;
+
+        FOR_UNROLL (short j = 0; j < Q; ++j) {
+            FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
+                m = max(m, mask_src[ii*NW]);
+            }
+
+            mask_src += args.nb31/2;
+        }
+
+        if (simd_max(m) > -MAXHALF/2) {
+            res = 1;
+        }
+    }
+
+    const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
+    const int32_t nblk0 = ((args.ne30 + C - 1)/C);
+
+    if (tiisg == 0) {
+        dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
+    }
+}
+
 constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
 constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
 constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
 constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
+
+constant bool FC_flash_attn_ext_bc_mask   [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 
 //constant float FC_flash_attn_ext_scale    [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -4499,6 +4531,8 @@ void kernel_flash_attn_ext_impl(
         device const char * v,
         device const char * mask,
         device const char * sinks,
+        device const char * pad,
+        device const char * blk,
         device       char * dst,
         threadgroup  half * shmem_f16,
         uint3   tgpig,
@@ -4564,6 +4598,13 @@ void kernel_flash_attn_ext_impl(
         pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
     }
 
+    {
+        const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
+        const int32_t nblk0 = ((args.ne11 + C - 1)/C);
+
+        blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
+    }
+
     {
         q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
 
@@ -4623,16 +4664,75 @@ void kernel_flash_attn_ext_impl(
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
-    for (int ic = 0; ic < args.ne11; ic += C) {
+    for (int ic0 = 0; ; ++ic0) {
+        int ic = ic0*C;
+        if (ic >= args.ne11) {
+            break;
+        }
+
+        // the last partial chunk uses the pad buffer as source
+        if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
+            k    = pad;
+            v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
+            mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+            const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+            const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+            k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
+            v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
+
+            if (!FC_flash_attn_ext_has_mask) {
+                threadgroup half * sm = (threadgroup half *) (sm2);
+
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
+
+                    for (short i = tiisg; i < C; i += NW) {
+                        if (ic + i >= args.ne11) {
+                            sm[2*j*SH + i] = -MAXHALF;
+                        }
+                    }
+                }
+            } else {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
+
+                    pm2[jj] = (device const half2 *) ((device const half *) mask +
+                            (iq1 + j)*C +
+                            (iq2%args.ne32)*(C*args.ne31) +
+                            (iq3%args.ne33)*(C*args.ne31*args.ne32));
+                }
+            }
+
+            ic = 0;
+        }
+
         // read the mask into shared mem
         if (FC_flash_attn_ext_has_mask) {
+            if (blk[ic0] == 0) {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    pm2[jj]
+= NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -4651,13 +4751,14 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -4668,26 +4769,24 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -4696,26 +4795,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -4729,7 +4824,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -4849,27 +4944,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t 
mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; - pv += 8*(NS20 - NO*NSG); - sst += 8; + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -4893,7 +5011,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4983,7 +5101,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 
0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -5028,8 +5146,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -5037,13 +5155,15 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5163,6 +5283,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -5189,9 +5310,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -5200,6 +5321,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -5305,12 +5427,38 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = 
iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -5322,7 +5470,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -5337,7 +5485,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -5435,7 +5583,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -5448,7 +5596,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -5573,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 
0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -5611,8 +5759,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -5620,13 +5768,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; @@ -5750,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; const short DV4 = DV/4; @@ -5770,21 +5920,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? 
tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5795,190 +5941,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? 
tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - 
quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel 
cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -5986,11 +6012,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -6002,10 +6029,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -7765,66 +7794,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 
tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -8310,12 +8333,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ec5202885..9c99b90fa 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -407,6 +407,7 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = auto() GPT_OSS = auto() LFM2 = auto() + LFM2MOE = auto() DREAM = auto() SMALLTHINKER = auto() LLADA = auto() @@ -749,6 +750,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", + MODEL_ARCH.LFM2MOE: "lfm2moe", MODEL_ARCH.DREAM: "dream", 
MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", @@ -2698,6 +2700,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.OUTPUT, ], + MODEL_ARCH.LFM2MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.SHORTCONV_CONV, + MODEL_TENSOR.SHORTCONV_INPROJ, + MODEL_TENSOR.SHORTCONV_OUTPROJ, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.ATTN_NORM, # operator_norm + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.SMALLTHINKER: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 67b274134..3e9a2dd8f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -358,6 +358,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.router", # openai-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker + "model.layers.{bid}.feed_forward.gate", # lfm2moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -367,6 +368,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_EXP_PROBS_B: ( "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe + "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe ), # Feed-forward up diff --git a/include/llama.h b/include/llama.h index 26e93366f..36969fd37 100644 --- a/include/llama.h +++ b/include/llama.h @@ -299,6 +299,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4fd083aa0..45f0d0e2c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -93,6 +93,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, @@ -2104,6 +2105,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, 
"token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2493,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/src/llama-arch.h b/src/llama-arch.h index bc4b04bb4..507fe5f37 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -97,6 +97,7 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 66e6c6a38..956c4e085 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index cb8832a35..dfb8439e0 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 44fced074..5bc5e58b4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -859,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4b2ff88f0..a229d148a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -119,6 +119,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -315,7 +316,7 @@ static 
ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -336,11 +337,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -1998,14 +2001,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2088,7 +2106,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? 
"true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -5870,6 +5888,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5881,11 +5900,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6367,7 +6398,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -18662,6 +18693,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18676,7 +18709,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, 
"model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18691,23 +18733,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19877,6 +19928,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -19927,6 +19979,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -20098,6 +20151,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: diff --git a/src/llama-model.h b/src/llama-model.h index eec564e70..20b59d952 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -107,6 +107,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index ff13874cd..4d487581a 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -249,10 +249,9 @@ struct mtmd_context { } else if (proj == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3; - tok_ov_img_start = {lookup_token("\n"), lookup_token(""), lookup_token("")}; + tok_ov_img_start = {lookup_token("\n\n"), lookup_token(""), lookup_token("")}; tok_ov_img_end = {lookup_token("")}; tok_row_end = {lookup_token("\n")}; - img_beg = ""; sli_img_start_tmpl = ""; } else if (proj == PROJECTOR_TYPE_PIXTRAL) { diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 2801319c9..8d57b4a16 100644 Binary files a/tools/server/public/index.html.gz and 
b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a21147613..de6e1a322 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1937,7 +1937,7 @@ private: void cleanup_pending_task(int id_target) { // no need lock because this is called exclusively by post() auto rm_func = [id_target](const server_task & task) { - return task.id_target == id_target; + return task.id == id_target; }; queue_tasks.erase( std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), @@ -3676,6 +3676,20 @@ struct server_context { alora_disabled_id = enabled_loras[0]; } + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process @@ -3700,6 +3714,11 @@ struct server_context { slot.n_prompt_tokens_processed++; slot.n_past++; + + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. + if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) { + break; + } } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); @@ -3730,6 +3749,39 @@ struct server_context { slot.i_batch = batch.n_tokens - 1; SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64); + + if (do_checkpoint) { + while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.ctx_checkpoints.front(); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } } } @@ -3853,40 +3905,6 @@ struct server_context { // prompt evaluated for next-token prediction 
slot.state = SLOT_STATE_GENERATING; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - const bool do_checkpoint = - (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full); - - if (do_checkpoint && params_base.n_ctx_checkpoints > 0) { - while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.ctx_checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ - /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), - /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id), - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } @@ -4184,6 +4202,7 @@ int main(int argc, char ** argv) { auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { static const std::unordered_set public_endpoints = { "/health", + "/v1/health", "/models", "/v1/models", "/api/tags" @@ -5232,6 +5251,7 @@ int main(int argc, char ** argv) { // register API routes svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) svr->Get (params.api_prefix + "/metrics", handle_metrics); svr->Get (params.api_prefix + "/props", handle_props); svr->Post(params.api_prefix + "/props", handle_props_change); diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte index 30d1f9d4b..e91673e98 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte @@ -1,8 +1,9 @@